
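A First Look at Reinforcement Learning and Hands-On Practice with OREAL

Task source: Docs

This article reproduces the course material on 50% of an A100. It also fixes several issues in the original docs that trip up beginners, such as the GPU settings in the config file, the batch size, the sampling temperature and the learning rates. The corrected .py files are included directly in the article. The article has two parts: an introduction to the course topic based on my own understanding (aimed at beginners), and a reproduction of the original task.
I. Key Points of This Course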

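  • Reinforcement learning (RL) is a machine learning approach that optimizes a policy using reward signals fed back from the environment. OREAL is an RL experimentation framework that combines modern algorithms with LLM inference and deployment, and it supports the full workflow from training to inference. Understanding the basic principles of RL and the structure and features of OREAL will help you run RL algorithm experiments and apply them in projects.

  • To follow the implementation, it helps to know how many models are involved and what each one is for. RL here is normally set up with three models. Strictly speaking one is enough: for example, chat assistants sometimes show two answers side by side, and the user acts as the Judger by picking the better one. The details are as follows: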

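  • Workflow of the three models

  • Prompt (input question)
       │
       ├────────────►[Policy model being trained] generates answer A (optimized by training)
       │
       ├────────────►[Reference model] generates answer B (the reference standard)
       │
       ▼
    [Judger model] compares A and B and scores A (this score is the reward)
       │
       ▼
    [PPO trainer] updates the parameters of the [policy model] based on the reward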

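  • Roles of the three models

| Model | Role | Trained? | Purpose | Version used here |
|---|---|---|---|---|
| Policy model | Protagonist (Policy) | ✅ Yes (updated with PPO) | Generates answers and learns to answer better | The 0.5B model chosen for this run (Qwen2.5-0.5B-Instruct) |
| Reference model | Helper (Reference) | ❌ No | Provides a "reasonable" answer for comparison | Also the 0.5B model, kept fixed |
| Judger model | Referee (Reward Model) | ❌ No | Compares the two answers (A vs B) and scores the policy model's answer | Qwen2.5-Math-1.5B-Instruct, served with lmdeploy |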
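To make the three roles concrete, here is a toy sketch in plain Python. Everything in it is hypothetical and unrelated to OREAL's code:

- The "policy" is a softmax over three candidate answers to a single question.
- The "reference" is a frozen copy of the policy's initial distribution.
- The "judger" is a rule that compares the sampled answer with a known correct one.

The update is a REINFORCE-style policy-gradient step with a KL penalty toward the reference. It is a simpler stand-in for the PPO trainer in the diagram, not OREAL's algorithm.

```python
import math
import random

CANDIDATES = ["41", "42", "43"]  # hypothetical answers to one question
GOLD = "42"                      # hypothetical ground-truth answer


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def judge(answer):
    """Rule-based judger: reward 1.0 for the correct answer, else 0.0."""
    return 1.0 if answer == GOLD else 0.0


def kl(p, q):
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def train(steps=200, lr=0.5, kl_coef=0.01, seed=0):
    rng = random.Random(seed)
    policy_logits = [0.0, 0.0, 0.0]      # trainable policy
    ref_probs = softmax(policy_logits)   # frozen reference (never updated)
    for _ in range(steps):
        probs = softmax(policy_logits)
        i = rng.choices(range(len(CANDIDATES)), weights=probs)[0]  # policy samples an answer
        reward = judge(CANDIDATES[i])                              # judger scores it
        kl_val = kl(probs, ref_probs)
        for k in range(len(policy_logits)):
            # d log pi(i) / d logit_k = onehot(i)_k - pi_k  (REINFORCE gradient)
            grad_logp = (1.0 if k == i else 0.0) - probs[k]
            # d KL(pi || ref) / d logit_k = pi_k * (log(pi_k / ref_k) - KL)
            grad_kl = probs[k] * (math.log(probs[k] / ref_probs[k]) - kl_val)
            # ascend on reward, descend on the KL penalty
            policy_logits[k] += lr * (reward * grad_logp - kl_coef * grad_kl)
    final = softmax(policy_logits)
    return final, kl(final, ref_probs)


if __name__ == "__main__":
    final_probs, final_kl = train()
    print({a: round(p, 3) for a, p in zip(CANDIDATES, final_probs)}, round(final_kl, 3))
```

The KL term keeps the policy from drifting arbitrarily far from the reference. Raising kl_coef pulls the final distribution back toward the (here uniform) reference.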

II. Reproducing the Original Task

    First, set up the Python environment. Here we install Miniconda and create a conda virtual environment.

  • wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
    chmod +x Miniconda3-latest-Linux-x86_64.sh
    ./Miniconda3-latest-Linux-x86_64.sh
    # follow the installer prompts, then open a new shell so that conda is on PATH
    conda create -n OREAL python=3.11
    conda activate OREAL

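    Next, install the dependencies (I assume CUDA is already installed). flash-attn has to be compiled, which makes it the hardest package to install. If the installation fails, run pip uninstall ninja && pip install -U ninja and try again.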

  • pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
    pip install flash-attn --no-build-isolation
    pip install fire xtuner[all]==0.2.0rc0

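    If installing flash-attn prints "Building" and takes a very long time, don't waste time compiling it. Download a prebuilt wheel from the official release page and install it yourself instead. For an environment built as above (Python 3.11, torch 2.6.0, CUDA 12.4), installing the matching prebuilt wheel is all you need.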
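A prebuilt flash-attn wheel has to match four things: your Python version, your torch version, the CUDA version torch was built with, and the C++11 ABI flag. The sketch below prints those fields, using only the standard library plus torch if it is installed.

- The ABI flag is read from torch._C._GLIBCXX_USE_CXX11_ABI, a private attribute, so treat it as best-effort.
- Wheel naming has changed between flash-attn releases, so compare these values against the file names on the release page rather than against a fixed pattern.

```python
import importlib.util
import platform


def env_report():
    """Collect the version fields a prebuilt flash-attn wheel has to match."""
    report = {
        "python": platform.python_version(),  # e.g. "3.11.x" -> cp311 wheels
        "torch": None,
        "cuda": None,
        "cuda_available": False,
        "cxx11_abi": None,
        "flash_attn_installed": importlib.util.find_spec("flash_attn") is not None,
    }
    if importlib.util.find_spec("torch") is not None:
        import torch

        report["torch"] = torch.__version__
        report["cuda"] = torch.version.cuda  # CUDA version torch was built with (None on CPU builds)
        report["cuda_available"] = torch.cuda.is_available()
        # private attribute; may disappear in future torch versions
        report["cxx11_abi"] = getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", None)
    return report


if __name__ == "__main__":
    for key, value in env_report().items():
        print(f"{key}: {value}")
```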

  • Run the Judger model

  • If running git lfs install works without problems and prints "Git LFS initialized.", you can skip this step; otherwise follow the steps below.

  • sudo apt update
    sudo apt install git-lfs
    git lfs install
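Git LFS matters here because without it, cloning a Hugging Face model repo still succeeds. It just leaves small text "pointer" files where the weights should be, and the model then fails to load.

Below is a quick check, sketched under the assumption of the standard LFS pointer format: files that begin with version https://git-lfs.github.com/spec/v1. find_lfs_pointers is my own helper name. After cloning in the next step, run it on ./Qwen2.5-Math-1.5B-Instruct. If it reports stubs, running git lfs pull inside the repository fetches the real files.

```python
from pathlib import Path

# First bytes of every Git LFS pointer file (standard pointer format, spec v1)
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"
WEIGHT_SUFFIXES = (".safetensors", ".bin")


def find_lfs_pointers(model_dir):
    """Return weight files that are still LFS pointer stubs instead of real weights."""
    stubs = []
    for path in sorted(Path(model_dir).rglob("*")):
        if path.is_file() and path.suffix in WEIGHT_SUFFIXES:
            with path.open("rb") as f:
                head = f.read(len(LFS_POINTER_PREFIX))
            if head == LFS_POINTER_PREFIX:
                stubs.append(path)
    return stubs


if __name__ == "__main__":
    # Returns an empty list if the directory does not exist yet.
    for stub in find_lfs_pointers("./Qwen2.5-Math-1.5B-Instruct"):
        print("LFS pointer, not real weights:", stub)
```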

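    Download the model locally

  • OREAL needs a language model as a verifier, combined with rule-based verification functions, to judge whether generated solutions are correct. The original repo uses a 72B model, which is far too costly for most people, so a 1.5B model stands in for it here.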

  • git clone https://hf-mirror.com/Qwen/Qwen2.5-Math-1.5B-Instruct
    pip install lmdeploy partial_json_parser
    export HF_ENDPOINT=https://hf-mirror.com
    lmdeploy serve api_server ./Qwen2.5-Math-1.5B-Instruct --chat-template qwen --log-level INFO --server-port 10003 --tp 1 --quant-policy 4 --cache-max-entry-count 0.1

    The output is shown below.
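Once the server is running, you can check that the Judger answers requests before starting training. lmdeploy's api_server exposes an OpenAI-compatible HTTP API.

The sketch below uses only urllib and assumes the server from the command above is listening on port 10003. It looks up the model name from /v1/models instead of hard-coding it, because the registered name depends on how the server was launched. The helper names are hypothetical.

```python
import json
import urllib.request

BASE_URL = "http://127.0.0.1:10003"  # matches --server-port 10003 above


def build_chat_request(model, question, base_url=BASE_URL, max_tokens=512):
    """Build a POST request for the OpenAI-style /v1/chat/completions endpoint."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": question}],
        "max_tokens": max_tokens,
        "temperature": 0.0,  # deterministic output for a quick sanity check
    }
    return urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def served_model_id(base_url=BASE_URL):
    """Ask the server which model name it registered (GET /v1/models)."""
    with urllib.request.urlopen(f"{base_url}/v1/models", timeout=10) as resp:
        return json.load(resp)["data"][0]["id"]


def ask(question, base_url=BASE_URL):
    """Send one question to the judger server and return the reply text."""
    request = build_chat_request(served_model_id(base_url), question, base_url)
    with urllib.request.urlopen(request, timeout=120) as resp:
        return json.load(resp)["choices"][0]["message"]["content"]
```

With the server up, print(ask("What is 1 + 2?")) should return a short solution from the 1.5B model.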

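  • Run the training code
    OREAL's RL training is very resource-hungry: training a 7B model takes about 9 hours of distributed training on 32 A100s, which is far too expensive here. So I heavily cut down the training code. Training involves three LLMs:
    1. The LLM used as math_judger (in the example below I scaled it down from 72B to 1.5B)
    2. The reference LLM (scaled down to 0.5B)
    3. The LLM being trained (scaled down to 0.5B)
    With this configuration, training fits on a 40 GB GPU (with occasional OOMs). Don't worry about the quality of the result. Many people reproducing DeepSeek R1-Zero have observed a rule of thumb: models below 3B do not learn to reason through RL and are better served by SFT. So don't expect much from the trained 0.5B model; the goal is just to get the pipeline running end to end.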
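Here is a rough back-of-the-envelope estimate of why a 0.5B model fits on a 40 GB card yet can still OOM. It rests on these assumptions:

- single_train.py below maps dtype = "auto" to float32.
- AdamW keeps two fp32 moment tensors per parameter.
- Qwen2.5-0.5B has about 0.49B parameters and a 151,936-token vocabulary.

Activations, the CUDA context and the generation KV cache are ignored.

```python
GB = 1e9  # decimal gigabytes


def full_finetune_state_gb(num_params, weight_bytes=4, grad_bytes=4, adam_state_bytes=8):
    """Weights + gradients + AdamW moments (exp_avg, exp_avg_sq) for full fine-tuning."""
    return num_params * (weight_bytes + grad_bytes + adam_state_bytes) / GB


def logits_gb(seq_len, vocab_size, batch=1, bytes_per_elem=4):
    """Size of one [batch, seq_len, vocab_size] logits tensor."""
    return batch * seq_len * vocab_size * bytes_per_elem / GB


if __name__ == "__main__":
    params = 0.49e9  # Qwen2.5-0.5B parameter count (assumed, from its model card)
    vocab = 151936   # Qwen2.5 vocabulary size (assumed, from config.json)
    print(f"fp32 weights + grads + AdamW states: {full_finetune_state_gb(params):.2f} GB")
    print(f"one fp32 logits tensor at 2048 tokens: {logits_gb(2048, vocab):.2f} GB")
```

So the persistent training state is only about 8 GB. However, each fp32 logits tensor for a 2048-token sequence is about 1.2 GB, and the RL step holds several of them at once: the actor logits and their log-softmax, the reference logits and their softmax, plus gradients. That is a plausible source of the occasional OOM.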

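    A write-up of a DeepSeek R1-Zero reproduction experiment: https://mp.weixin.qq.com/s/Z7P61IV3n4XYeC0Et_fvwg

    First, grab the OREAL source code.

    If you have trouble pulling code from GitHub, the methods here can speed up access: https://github.akams.cn/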

  • git clone https://github.com/InternLM/OREAL

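    Below are the three files I modified; replacing the original files with them is enough to start training.
    config.py goes to ./oreal/configs/config.py
    single_train.py goes in the repository root
    trajectory.py replaces ./oreal/datasets/trajectory.py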
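If you keep the three modified files together in one folder, a small helper can copy them into place and back up each original it overwrites. This is only a convenience sketch: the folder layout and the .orig backup suffix are my own assumptions.

```python
import shutil
from pathlib import Path

# Destination of each patched file, relative to the OREAL repository root (from the list above).
TARGETS = {
    "config.py": Path("oreal/configs/config.py"),
    "single_train.py": Path("single_train.py"),
    "trajectory.py": Path("oreal/datasets/trajectory.py"),
}


def install_patched_files(src_dir, repo_root, backup_suffix=".orig"):
    """Copy the patched files into the repo, keeping a one-time backup of any file replaced."""
    installed = []
    for name, rel_dest in TARGETS.items():
        src = Path(src_dir) / name
        dest = Path(repo_root) / rel_dest
        dest.parent.mkdir(parents=True, exist_ok=True)
        backup = dest.with_name(dest.name + backup_suffix)
        if dest.exists() and not backup.exists():
            shutil.copy2(dest, backup)  # keep the original OREAL version once
        shutil.copy2(src, dest)
        installed.append(dest)
    return installed
```

For example, call install_patched_files("./patched", "./OREAL") after cloning.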

  • config.py

# Model Related Settings
# actor = "internlm/OREAL-7B-SFT"
actor = "Qwen/Qwen2.5-0.5B-Instruct"
reference = actor
token_level_rm = actor
# reference = 'Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4'

# Tokenizer related settings
# jinja2 template for hf tokenizer
chat_template = (
    "{% set sys_prompt = \"You are an expert mathematician with extensive experience in mathematical competitions. You approach problems through systematic thinking and rigorous reasoning. When solving problems, follow these thought processes:\\n\\n## Deep Understanding\\nTake time to fully comprehend the problem before attempting a solution. Consider:\\n- What is the real question being asked?\\n- What are the given conditions and what do they tell us?\\n- Are there any special restrictions or assumptions?\\n- Which information is crucial and which is supplementary?\\n\\n## Multi-angle Analysis\\nBefore solving, conduct thorough analysis:\\n- What mathematical concepts and properties are involved?\\n- Can you recall similar classic problems or solution methods?\\n- Would diagrams or tables help visualize the problem?\\n- Are there special cases that need separate consideration?\\n\\n## Systematic Thinking\\nPlan your solution path:\\n- Propose multiple possible approaches\\n- Analyze the feasibility and merits of each method\\n- Choose the most appropriate method and explain why\\n- Break complex problems into smaller, manageable steps\\n\\n## Rigorous Proof\\nDuring the solution process:\\n- Provide solid justification for each step\\n- Include detailed proofs for key conclusions\\n- Pay attention to logical connections\\n- Be vigilant about potential oversights\\n\\n## Repeated Verification\\nAfter completing your solution:\\n- Verify your results satisfy all conditions\\n- Check for overlooked special cases\\n- Consider if the solution can be optimized or simplified\\n- Review your reasoning process\\n\\nRemember:\\n1. Take time to think thoroughly rather than rushing to an answer\\n2. Rigorously prove each key conclusion\\n3. Keep an open mind and try different approaches\\n4. Summarize valuable problem-solving methods\\n5. "
    "Maintain healthy skepticism and verify multiple times\\n\\nYour response should reflect deep mathematical understanding and precise logical thinking, making your solution path and reasoning clear to others.\\n\\nWhen you're ready, present your complete solution with:\\n- Clear problem understanding\\n- Detailed solution process\\n- Key insights\\n- Thorough verification\\n\\nFocus on clear, logical progression of ideas and thorough explanation of your mathematical reasoning. Provide answers in the same language as the user asking the question, repeat the final answer using a '\\\\boxed{}' without any units, you have [[8192]] tokens to complete the answer.\" %}{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- sys_prompt }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\n' ~ sys_prompt ~ '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif "
    "message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n"
)
stop_word = "<|im_end|>"
dtype = "auto"
selective_recompute = 1.0
cpu_offload = False  # ===== Modified: set to True to offload to CPU if GPU memory is tight (not read by single_train.py) =====

# Dataset Related Settings
data_difficulty_balance_cfg = [
    # pass rate range, repeat times
    ((0.0, 0.2), 6),
    ((0.2, 0.4), 4),
    ((0.4, 0.6), 4),
    ((0.6, 0.8), 2),
]
datasets = "internlm/OREAL-RL-Prompts"
num_workers = 0  # ===== Modified: load data in the main process (saves CPU RAM, not GPU memory) =====

# Generate Related Settings
gen_max_new = 1024  # ===== Modified: reduced from 2048 to 1024 to lower GPU memory use =====
gen_max_length = 2048  # ===== Modified: maximum generation length adjusted to match =====
gen_top_k = 0  # set to 0 means not use topk sampling
gen_top_p = 0.9
temperature = 0.8  # ===== Modified: lower sampling temperature for more stable outputs =====
gen_do_sample = True
max_prefill_batch = 8  # ===== Modified: smaller prefill batch to reduce peak GPU memory (not read by single_train.py) =====
gen_global_batch = 1  # ===== Modified: batch size 1, easier on GPU memory =====
prompt_repeat_k = 1  # sample k times for each prompt

# Optimizer Related Settings
rl_global_batch = gen_global_batch
rl_mirco_batch = 1  # ===== Modified: smaller micro-batch =====
filter_trajectory = True  # sample one correct and one incorrect trajectory for each prompt
warmup_steps = 10
total_steps = 90
actor_freeze_steps = 10  # freeze actor and only update token level reward model for the first 10 steps
actor_lr = 5e-7
actor_min_lr = 1e-7
token_level_rm_lr = 1e-6  # ===== Modified: somewhat lower learning rate for the token-level reward model =====
token_level_rm_lr_min = 4e-7
wd = 0.01  # weight decay
max_grad_norm = 1  # gradient clipping

# importance sampling setting with token level reward model
threshold_rescale = True
correct_threshold = 0.5
incorrect_threshold = 0.5
# topk_rescale = True
# correct_topk_ratio = 0.25
# incorrect_topk_ratio = 0.25
reward_shaping_type = "rloo"
loss_type = "per_token"
positive_loss_factor = 1.0
negative_loss_factor = 0.5
pos_mult_adv = True
kl_coef = 0.01  # KL coefficient

# General Settings
work_dir = "work_dirs"  # directory to save logs and checkpoints
checkpoint_interval = 10  # interval to save checkpoint, <1 means save by proportion, >=1 means save by steps
log_interval = 1  # interval steps for logging
seed = 0  # random seed
debug = False  # set log level to DEBUG

# judger related settings
judgers_config = dict(
    math_judger=dict(  # math judger related settings
        hosts=["127.0.0.1:10003"],
        stop_word=stop_word,
        thinking_finish_words=["<conclude>", "**Final Answer**", "</think>"],
        num_processes=1,
        concurrency_per_proc=(1, 1),
    )
)
data_judger_mapping = dict(math=["math_judger"])
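The chat template above asks the model to repeat its final answer in \boxed{}, and the judger config lists thinking_finish_words. The plain-Python sketch below shows what a simple rule-based check of that kind could look like:

1. Take the text after the last finish marker.
2. Extract the last \boxed{...}.
3. Compare it with the reference answer after light normalization.

This is my own illustration, not OREAL's judger. In OREAL the rule-based check is combined with the LLM verifier served by lmdeploy, and real math answer matching needs far more normalization than this.

```python
FINISH_WORDS = ("<conclude>", "**Final Answer**", "</think>")  # same list as thinking_finish_words


def last_boxed(text):
    """Return the content of the last \\boxed{...} in text (nested braces allowed), or None."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i, depth, out = start + len("\\boxed{"), 1, []
    while i < len(text):
        ch = text[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return "".join(out)
        out.append(ch)
        i += 1
    return None  # unbalanced braces


def answer_part(response, finish_words=FINISH_WORDS):
    """Keep only the text after the last 'thinking finished' marker (hypothetical use of the markers)."""
    cut = max((response.rfind(w) + len(w) for w in finish_words if w in response), default=0)
    return response[cut:]


def normalize(ans):
    """Very light normalization: drop spaces, \\left/\\right and a trailing period."""
    return ans.replace(" ", "").replace("\\left", "").replace("\\right", "").rstrip(".")


def rule_based_judge(response, gold):
    """Reward 1.0 if the last boxed answer after the thinking part matches gold, else 0.0."""
    pred = last_boxed(answer_part(response))
    return 1.0 if pred is not None and normalize(pred) == normalize(gold) else 0.0
```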

  • single_train.py

# Copyright (c) InternLM. All rights reserved.
import json
import os
import sys
import time
from collections import OrderedDict
from datetime import datetime, timedelta

import fire
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

from oreal.datasets import (
    InferDataset,
    OrealPromptDataset,
    PromptCollator,
    TrajectoryCollator,
    TrajectoryDataset,
    TrajectoryDatasetWithFilter,
)
from oreal.judgers import ParallelRouter
from oreal.utils import Config
from xtuner._lite.algorithms.sft import SftCollator

logger = None


def is_interval(step, total_steps, interval):
    return (step + 1) % interval == 0 or (step + 1) == total_steps


def train_oreal(cfg_path, **kwargs):
    args = Config.fromfile(cfg_path)
    args.update(kwargs)

    ###########################################################################
    #                           1. Environment                                #
    ###########################################################################
    rank = 0  # Single machine rank
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    args.work_dir = os.path.join(args.work_dir, timestamp)
    os.makedirs(args.work_dir, exist_ok=True)
    log_file = os.path.join(args.work_dir, f"rank{rank}.log")

    class SimpleLogger:
        @staticmethod
        def info(message):
            print(message)

        @staticmethod
        def warning(message):
            print(f"WARNING: {message}")

        @staticmethod
        def error(message):
            print(f"ERROR: {message}")

        @staticmethod
        def success(message):
            print(f"SUCCESS: {message}")

        @staticmethod
        def debug(message):
            print(f"DEBUG: {message}")

    global logger
    logger = SimpleLogger()
    logger.info(args)
    logger.info(f"Work directory: {args.work_dir}")
    # -------------------    Environment  End  ------------------------------ #

    ###########################################################################
    #                     2. Dataset & Dataloader                             #
    ###########################################################################
    tokenizer = AutoTokenizer.from_pretrained(args.actor, trust_remote_code=True, padding_side="right")
    if args.chat_template is not None:
        logger.info(f"[CHAT_TEMPLATE] {args.chat_template}")
        tokenizer.chat_template = args.chat_template

    prompt_dataset = OrealPromptDataset(
        args.datasets,
        tokenizer,
        difficulty_balance_cfg=args.data_difficulty_balance_cfg,
    )
    prompt_collator = PromptCollator(pack_batch=False)
    prompt_dataloader = DataLoader(
        prompt_dataset,
        batch_size=args.gen_global_batch // args.prompt_repeat_k,
        num_workers=args.num_workers,
        shuffle=True,
        collate_fn=prompt_collator,
        persistent_workers=args.num_workers > 0,
    )
    logger.info(f"[Dataset] {len(prompt_dataset)} prompts.")
    logger.info(f"[Dataloader] {len(prompt_dataloader)} batches.")
    # -------------------    Dataset & Dataloader  End  --------------------- #

    # ---------------------    Router  Start  ------------------------------- #
    judger_router = ParallelRouter(
        judgers_config=args.judgers_config,
        data_judger_mapping=args.data_judger_mapping,
        logger=logger,
    )
    stop_token_ids = []
    word_ids = tokenizer.encode(args.stop_word, add_special_tokens=False)
    if len(word_ids) > 1:
        raise NotImplementedError("The stop word must be a single token.")
    stop_token_ids.append(word_ids[0])

    ###########################################################################
    #                      4. Optimizer & Scheduler                           #
    ###########################################################################
    # dtype = "auto" in config.py falls through to float32 here
    dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16 if args.dtype == "bf16" else torch.float32
    actor_model = AutoModelForCausalLM.from_pretrained(args.actor, torch_dtype=dtype, device_map="auto")
    # ref_model = AutoModelForCausalLM.from_pretrained(args.reference, torch_dtype=dtype, device_map="cuda:1")
    # ref_model = AutoModelForCausalLM.from_pretrained(args.reference, torch_dtype="auto", device_map="cuda:1")
    actor_model = torch.compile(actor_model)
    # ref_model = torch.compile(ref_model)
    # NOTE: to save memory the reference model is the same object as the actor,
    # so the KL penalty computed in the RL stage is numerically ~0.
    ref_model = actor_model
    actor_model.train()
    ref_model.eval()

    actor_params = [p for p in actor_model.parameters() if p.requires_grad]
    actor_optimizer = AdamW(actor_params, lr=args.actor_lr, weight_decay=args.wd)

    total_steps = args.total_steps
    warmup_steps = args.warmup_steps

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(actor_optimizer, warmup_fn)
    cosine_scheduler = CosineAnnealingLR(actor_optimizer, T_max=total_steps - warmup_steps, eta_min=args.actor_min_lr)
    # ----------------    Optimizer & Scheduler End   ----------------------- #

    ###########################################################################
    #                          5. Training                                    #
    ###########################################################################
    trajectory_dataset = TrajectoryDataset()
    prompt_iterator = iter(prompt_dataloader)
    start_step = 0
    start_train_t = time.time()
    logger.info("[Train] Begin Train Loop")
    for step in range(start_step, total_steps):
        if step <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_last_lr()[0]

        step_start_t = time.time()
        step_rl_loss = 0
        step_kl_penalty_loss = 0
        step_token_level_rm_loss = 0

        data = next(prompt_iterator)
        prompt_input_ids = data["input_ids"]
        prompt_input_ids_cuda = prompt_input_ids.to("cuda:0")

        # Stage 1, Actor Model Generation
        step_gen_start_t = time.time()
        actor_model.eval()
        responses = actor_model.generate(
            prompt_input_ids_cuda,
            max_length=args.gen_max_length,
            max_new_tokens=args.gen_max_new,
            do_sample=args.gen_do_sample,
            top_k=args.gen_top_k,
            top_p=args.gen_top_p,
            temperature=args.temperature,
        )
        # generate() returns prompt + new tokens for decoder-only models; keep only the
        # new tokens, because InferDataset concatenates the prompt again.
        responses = responses[:, prompt_input_ids_cuda.shape[1]:].to("cpu")
        actor_model.train()
        step_gen_time = time.time() - step_gen_start_t
        response_texts = [tokenizer.decode(res, skip_special_tokens=False) for res in responses]

        # Stage 2, Infer
        step_infer_start_t = time.time()
        infer_dataset = InferDataset(
            prompt_input_ids,
            responses,
            data["message_data"],
            data["metadata"],
        )
        infer_dataloader = DataLoader(
            infer_dataset,
            batch_size=args.rl_mirco_batch,
            num_workers=0,
            collate_fn=SftCollator(pack_batch=False),
            shuffle=False,
            persistent_workers=False,
        )
        policies = []
        for infer_packed_seq in infer_dataloader:
            infer_input_ids = infer_packed_seq["input_ids"]
            infer_labels = infer_packed_seq["labels"]
            infer_input_ids_cuda = infer_input_ids.to("cuda:0")
            with torch.no_grad():
                actor_logits = actor_model(infer_input_ids_cuda).logits
            actor_logits = actor_logits.to("cpu")
            policies.extend(
                [
                    {
                        "input_ids": input_ids.tolist(),
                        "labels": labels.tolist(),
                        "num_tokens": len(labels),
                        "sequence_text": tokenizer.decode(input_ids, skip_special_tokens=False),
                    }
                    for input_ids, labels in zip(infer_input_ids, infer_labels)
                ]
            )
        step_infer_time = time.time() - step_infer_start_t

        # Get Judger Reward
        # NOTE: placeholder rewards; judger_router above is not queried in this simplified script.
        if len(policies) > 0:
            judger_rewards = [0.0] * len(policies)  # Replace with actual judger rewards if needed
            for i, policy in enumerate(policies):
                policy["judger_reward"] = judger_rewards[i]
                policy["judger_advantage"] = ""

        # Stage 4, RL
        step_rl_start_t = time.time()
        trajectory_dataset.update(policies)
        rl_loader = DataLoader(
            trajectory_dataset,
            batch_size=args.rl_mirco_batch,
            num_workers=0,
            collate_fn=TrajectoryCollator(pack_batch=False),
            shuffle=False,
            persistent_workers=False,
        )
        step_rl_loss = 0
        step_kl_penalty_loss = 0
        step_token_level_rm_loss = 0
        for packed_policy in rl_loader:
            input_ids_cuda = packed_policy["input_ids"].to("cuda:0")
            labels = packed_policy["labels"].to("cuda:0")
            judgerAdvantages = packed_policy["judger_advantages"]

            # InferDataset builds labels that are already shifted (labels[t] is the target
            # for position t), so the loss is computed directly instead of passing labels=
            # to the HF model, which would shift them a second time.
            # NOTE: this is a plain causal-LM loss on the sampled response; the
            # (placeholder) judger rewards do not enter it.
            outputs = actor_model(input_ids_cuda)
            loss = torch.nn.functional.cross_entropy(
                outputs.logits.flatten(0, 1).float(),
                labels.flatten(),
                ignore_index=-100,
            )

            # Add KL penalty
            with torch.no_grad():
                ref_logits = ref_model(input_ids_cuda).logits
            kl_loss = torch.nn.functional.kl_div(
                torch.nn.functional.log_softmax(outputs.logits, dim=-1),
                torch.nn.functional.softmax(ref_logits, dim=-1),
                reduction="batchmean",
            )
            total_loss = loss + args.kl_coef * kl_loss
            total_loss.backward()
            step_rl_loss += loss.item()
            step_kl_penalty_loss += kl_loss.item()

        actor_optimizer.step()
        actor_optimizer.zero_grad()

        step_rl_time = time.time() - step_rl_start_t
        step_time = time.time() - step_start_t
        eta = step_time * (total_steps - step)
        eta = timedelta(seconds=int(eta))
        logger.info(
            "[Train] Step "
            f"{step + 1}/{total_steps}  "
            f"actor_lr: {cur_lr:.3e}  "
            f"rl_loss: {step_rl_loss:.3f}  "
            f"kl_penalty_loss: {step_kl_penalty_loss:.3f}  "
            f"total_time: {step_time:.2f}s  "
            f"eta: {eta}"
        )

        if is_interval(step, total_steps, args.checkpoint_interval):
            save_path = os.path.join(args.work_dir, f"checkpoint_{step+1}")
            actor_model.save_pretrained(save_path)
            tokenizer.save_pretrained(save_path)

    train_cost_time = time.time() - start_train_t
    logger.success(f"[Train] Cost {timedelta(seconds=int(train_cost_time))}")


if __name__ == "__main__":
    fire.Fire(train_oreal)
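The two schedulers in single_train.py combine into a linear warmup followed by a cosine decay. The closed-form sketch below mirrors the value the loop logs as actor_lr, assuming the config.py defaults. It is derived from reading the loop. PyTorch's chainable CosineAnnealingLR update agrees with this closed form only up to floating-point rounding.

```python
import math


def lr_at(step, base_lr=5e-7, min_lr=1e-7, warmup_steps=10, total_steps=90):
    """Learning rate logged at 0-based loop iteration `step` in single_train.py.

    For step <= warmup_steps the loop calls warmup_scheduler.step() (LambdaLR with a
    linear ramp capped at 1); afterwards it calls cosine_scheduler.step(), whose
    t-th call gives min_lr + (base_lr - min_lr) * (1 + cos(pi * t / T_max)) / 2.
    """
    if step <= warmup_steps:
        return base_lr * min((step + 1) / warmup_steps, 1.0)
    t = step - warmup_steps
    t_max = total_steps - warmup_steps
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * t / t_max)) / 2


if __name__ == "__main__":
    for s in (0, 9, 10, 11, 50, 89):
        print(s, f"{lr_at(s):.3e}")
```

The schedulers are stepped at the top of each iteration, before optimizer.step(). So the first update already uses 5e-8 rather than 0, and PyTorch prints a warning about the call order that is harmless here.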

  • trajectory.py

# Copyright (c) InternLM. All rights reserved.
import json
import random

import numpy as np
import torch
from xtuner._lite import get_logger
from xtuner._lite.algorithms.sft.dataset import SftCollator

logger = get_logger()


class InferDataset(torch.utils.data.Dataset):
    def __init__(self, prompts_input_ids, responses_ids, message_data, metadata):
        super().__init__()
        assert (
            len(prompts_input_ids)
            == len(responses_ids)
            == len(message_data)
            == len(metadata)
        ), f"The length of prompts_input_ids, responses_ids, message_data, metadata should be the same, but got {len(prompts_input_ids)}, {len(responses_ids)}, {len(message_data)}, {len(metadata)}"
        self.prompts_input_ids = prompts_input_ids
        self.responses_ids = responses_ids
        self.message_data = message_data
        self.metadata = metadata

    def __len__(self):
        return len(self.prompts_input_ids)

    def __getitem__(self, item):
        prompt_input_ids = self.prompts_input_ids[item]
        response_ids = self.responses_ids[item]
        num_prefill_tokens = len(prompt_input_ids)
        # input_ids = prompt_input_ids + response_ids
        input_ids = torch.cat([prompt_input_ids, response_ids])
        # labels = [-100] * (num_prefill_tokens - 1) + response_ids + [-100]
        device = prompt_input_ids.device
        # Build each part of labels as a tensor on the same device
        part1 = torch.full((num_prefill_tokens - 1,), -100).to(device)
        part2 = response_ids.to(device)
        part3 = torch.tensor([-100]).to(device)
        # Concatenate the parts with torch.cat
        labels = torch.cat((part1, part2, part3))
        return {
            "input_ids": input_ids,
            "labels": labels,
            "num_tokens": len(input_ids),
            "message_data": self.message_data[item],
            "metadata": self.metadata[item],
        }


class TrajectoryDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        self._num_action_tokens = 0
        self._num_total_tokens = 0
        self._trajectories = []

    @property
    def num_action_tokens(self):
        # int() also works when update() received no trajectories (plain Python 0)
        return int(self._num_action_tokens)

    @property
    def num_total_tokens(self):
        return self._num_total_tokens

    def update(self, trajectories):
        num_total_tokens = 0
        num_action_tokens = 0
        for data in trajectories:
            labels = np.array(data["labels"])
            num_total_tokens += labels.size
            num_action_tokens += (labels >= 0).sum()
        self._num_action_tokens = num_action_tokens
        self._num_total_tokens = num_total_tokens
        self._trajectories = trajectories

    def dump_jsonl(self, path, tokenizer, debug=False):
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                json_line = {
                    "sequence": (
                        data["sequence_text"]
                        if "sequence_text" in data
                        else tokenizer.decode(data["input_ids"])
                    ),
                    "num_tokens": data["num_tokens"],
                }
                json_line["judger_reward"] = data["judger_reward"]
                json_line["judger_advantage"] = data["judger_advantage"]
                if debug:
                    json_line["input_ids"] = data["input_ids"]
                    json_line["labels"] = data["labels"]
                json_str = json.dumps(json_line, ensure_ascii=False)
                f.write(json_str + "\n")

    def dump_log(self, path, tokenizer, debug=False):
        with open(path, "w", encoding="utf8") as f:
            for data in self._trajectories:
                log_string = f"[sequence]:\n{data['sequence_text'] if 'sequence_text' in data else tokenizer.decode(data['input_ids'])}\n\n"
                log_string += f"[num_tokens]: {data['num_tokens']}\n"
                log_string += f"[judger_reward]: {data['judger_reward']}\n"
                log_string += f"[judger_advantage]: {data['judger_advantage']}\n"
                f.write(log_string + "\n\n=======================\n")

    def __len__(self):
        return len(self._trajectories)

    def __getitem__(self, item):
        return self._trajectories[item]


class TrajectoryDatasetWithFilter(TrajectoryDataset):
    def __init__(self, repeat_k=1, only_keep_1_pair=True):
        super().__init__()
        self.repeat_k = repeat_k
        self.only_keep_1_pair = only_keep_1_pair

    def update(self, trajectories):
        # split trajectories into k groups: (a, a, b, b, c, c) -> [(a, a), (b, b), (c, c)]
        groups = [
            trajectories[i : i + self.repeat_k]
            for i in range(0, len(trajectories), self.repeat_k)
        ]
        keeped_trajectories = []
        for group in groups:
            correctness = [1 if data["judger_reward"] == 1 else 0 for data in group]
            correct = [data for data in group if data["judger_reward"] == 1]
            incorrect = [data for data in group if data["judger_reward"] != 1]
            pass_rate = sum(correctness) / len(correctness)
            if self.only_keep_1_pair:
                if pass_rate == 1 or pass_rate == 0:
                    continue
                # max keep 1 correct and 1 incorrect
                correct = random.choice(correct)
                incorrect = random.choice(incorrect)
                correct["pass_rate"] = pass_rate
                incorrect["pass_rate"] = pass_rate
                keeped_trajectories.append(correct)
                keeped_trajectories.append(incorrect)
            else:
                if pass_rate == 1 or pass_rate == 0:
                    continue
                for data in group:
                    data["pass_rate"] = pass_rate
                    keeped_trajectories.append(data)
        super().update(keeped_trajectories)


class TrajectoryCollator(SftCollator):
    def __call__(self, instances):
        data = super().__call__(instances)
        data["judger_rewards"] = [item["judger_reward"] for item in instances]
        data["judger_advantages"] = [item["judger_advantage"] for item in instances]
        if "pass_rate" in instances[0]:
            data["pass_rate"] = [item["pass_rate"] for item in instances]
        return data
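TrajectoryDatasetWithFilter is where the "one correct and one incorrect trajectory per prompt" rule (filter_trajectory in config.py) is applied. The index-only sketch below re-implements its grouping logic on a plain list of rewards so that it runs without xtuner. The function name is hypothetical.

It also shows why that filter would discard everything in this simplified setup. With prompt_repeat_k = 1, and placeholder rewards of 0.0, every group's pass rate is 0 or 1. That is consistent with single_train.py using the plain TrajectoryDataset instead.

```python
import random


def keep_one_pair_per_prompt(rewards, repeat_k, seed=0):
    """Mirror of TrajectoryDatasetWithFilter.update(only_keep_1_pair=True) on plain rewards.

    rewards: judger rewards, with each prompt's repeat_k samples stored consecutively.
    Returns (kept_indices, pass_rates); each kept prompt contributes one correct and one
    incorrect sample, and prompts whose samples are all correct or all wrong are dropped.
    """
    rng = random.Random(seed)
    kept, pass_rates = [], []
    for start in range(0, len(rewards), repeat_k):
        group = list(range(start, min(start + repeat_k, len(rewards))))
        correct = [i for i in group if rewards[i] == 1]
        incorrect = [i for i in group if rewards[i] != 1]
        pass_rate = len(correct) / len(group)
        if pass_rate in (0.0, 1.0):
            continue  # no contrast between the samples of this prompt
        kept += [rng.choice(correct), rng.choice(incorrect)]
        pass_rates += [pass_rate, pass_rate]
    return kept, pass_rates


if __name__ == "__main__":
    print(keep_one_pair_per_prompt([1, 0, 1, 1, 0, 0], repeat_k=2))  # only the first prompt survives
    print(keep_one_pair_per_prompt([0.0, 0.0, 0.0], repeat_k=1))     # everything is dropped
```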

Then run the following commands to start training the LLM with RL (you may need to press Enter twice):
 

export HF_ENDPOINT=https://hf-mirror.com
python single_train.py ./oreal/configs/config.py --total_steps 90 --work_dir ./work_dir/my_train

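Test the trained model
The trained model is saved under the --work_dir directory given above. single_train.py adds a timestamped subdirectory, so checkpoints end up at paths like ./work_dir/my_train/<timestamp>/checkpoint_90. Pick the last checkpoint and load it.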
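To pick the newest checkpoint programmatically, here is a small sketch with a hypothetical helper name. It relies on single_train.py naming each run directory with a %Y%m%d%H%M%S timestamp, which sorts chronologically.

```python
import re
from pathlib import Path

CKPT_RE = re.compile(r"checkpoint_(\d+)")


def latest_checkpoint(work_dir):
    """Return <work_dir>/<newest timestamp>/checkpoint_<largest N>, or None if there is none."""
    runs = sorted((p for p in Path(work_dir).iterdir() if p.is_dir()), reverse=True)
    for run in runs:  # newest run first
        steps = []
        for p in run.iterdir():
            m = CKPT_RE.fullmatch(p.name)
            if p.is_dir() and m:
                steps.append((int(m.group(1)), p))
        if steps:
            return max(steps)[1]
    return None
```

For example, latest_checkpoint("./work_dir/my_train") gives the path to pass as model_name below.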

Use the code below to test the trained model:
 

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

model_name = "./checkpoint_90"  # Replace with the path of your trained checkpoint; relative paths start from the current working directory.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create a text-generation pipeline
text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cuda:0",
)

# Input prompt
text_input = "证明费马大定理"  # "Prove Fermat's Last Theorem"

# Generate text (sampling enabled so that temperature/top_k/top_p take effect)
result = text_generator(
    text_input,
    max_length=500,
    # num_return_sequences=3,
    no_repeat_ngram_size=2,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
)

# Print the generated text
print(result[0]["generated_text"])

The output is as follows:
