当前位置: 首页 > news >正文

微调Qwen2大语言模型加入领域知识

目录

  • 试用Qwen2做推理
  • 安装LLaMA-Factory
  • 使用自有数据集微调Qwen2
  • 验证微调效果

试用Qwen2做推理

参考:https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html

from transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" # the device to load the model onto# Now you do not need to add "trust_remote_code=True"
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B-Instruct",torch_dtype="auto",device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")# Instead of using model.chat(), we directly use model.generate()
# But you need to use tokenizer.apply_chat_template() to format your inputs as shown below
prompt = "Give me a short introduction to large language model."
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)# Directly use generate() and tokenizer.decode() to get the output.
# Use `max_new_tokens` to control the maximum output length.
generated_ids = model.generate(model_inputs.input_ids,max_new_tokens=512
)
generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)

执行此操作后模型会缓存到如下目录
/root/.cache/huggingface/hub/models–Qwen–Qwen2-7B-Instruct/

安装LLaMA-Factory

参考:https://github.com/hiyouga/LLaMA-Factory

git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
pip install -e ".[torch,metrics]"

使用自有数据集微调Qwen2

参考:https://qwen.readthedocs.io/en/latest/training/SFT/llama_factory.html

启动web ui 界面来进行微调

llamafactory-cli webui

在这里插入图片描述

数据集选择提前准备好的自有数据集train_data.json
编辑LLaMA-Factory/data路径下的dataset_info.json,把自有数据集加入进去
在这里插入图片描述

按照自己需求配置训练参数,所有的参数都配置好之后,点一下“预览命令”,确认命令没有问题之后,就可以点击“开始”进行训练了。训练的过程中可以看到 loss的变化曲线、训练耗时等。

参考:https://blog.csdn.net/u012505617/article/details/137864437

验证微调效果

可以直接点击webui界面中的“chat”选项卡,然后点击“加载模型”,然后就可以进行对话了,不过使用A100实测非常慢,一次会话花了四十多分钟。

所以建议在webui 界面训练好模型之后点击“Export”选项卡,然后,在“模型路径”中输入原始模型路径,然后在“检查点路径”中选择自己微调得到的 checkpoint路径,然后在“最大分块大小(GB)”中设置为4,同时设置一下导出目录,最后点击“开始导出”,就可以看到输出的模型了

本地部署模型,并做推理测试
参考:https://qwen.readthedocs.io/en/latest/run_locally/llama.cpp.html

这里使用的是vLLM
参考:https://qwen.readthedocs.io/en/latest/deployment/vllm.html

from transformers import AutoTokenizer
from vllm import LLM, SamplingParams# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")# Pass the default decoding hyperparameters of Qwen2-7B-Instruct
# max_tokens is for the maximum length for generation.
sampling_params = SamplingParams(temperature=0.1, top_p=0.1, repetition_penalty=1.05, max_tokens=512) #这里想要推理的结果跟我们微调的预期严格一致,所以调低了前两个参数的值。# Input the model name or path. Can be GPTQ or AWQ models.
llm = LLM(model="LLaMA-Factory/saves/Qwen2-7B-Chat/sft") #这里填写前面Export时填写的Path# Prepare your prompts
prompt = "这里输入与微调数据集中类似的问题来验证"
messages = [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True
)# generate outputs
outputs = llm.generate([text], sampling_params)# Print the outputs.
for output in outputs:prompt = output.promptgenerated_text = output.outputs[0].textprint(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
http://www.lryc.cn/news/399126.html

相关文章:

  • 【Linux】内核文件系统系统调用流程摸索
  • 【HZHY-AI300G智能盒试用连载体验】文档资料
  • Linux--深入理与解linux文件系统与日志文件分析
  • Postman 中的 API 安全性测试:最佳实践与技巧
  • PTC可复位保险丝 vs 传统型保险丝:全面对比分析
  • 深入了解Rokid UXR2.0 SDK内置的Unity AR Glass开发组件
  • Lottery 分布式抽奖(个人向记录总结)
  • 我的AI音乐梦:ChatGPT帮我做专辑
  • 新手-前端生态
  • C#中的类
  • 探索数据库编程:基础与进阶之存储函数
  • Count数据转换为TPM数据方法整理-常规方法、DGEobj.utils和IOBR包
  • 简易限流实现
  • 用Qwt进行图表和数据可视化开发
  • sqlalchemy使用with_entities返回指定数据列
  • express
  • HTML网页大设计-家乡普宁德安里
  • 深度学习:从数据采集到模型测试的全面指南
  • Excel第29享:基于sum嵌套sumifs的多条件求和
  • Elasticsearch:Node.js ECS 日志记录 - Morgan
  • ChatGPT对话:Python程序自动模拟操作网页,无法弹出下拉列表框
  • Unity 之 抖音小游戏集成排行榜功能详解
  • 【学习笔记】Redis学习笔记——第13章 客户端
  • Android中的JSON解析:从基础到实践
  • 力扣-回溯法
  • 240713_昇思学习打卡-Day25-LSTM+CRF序列标注(4)
  • python requests关闭https校验
  • PG大会周五于杭州举办;Pika发布4.0;阿里云MySQL上线Zero-ETL集成能力
  • 虚拟机vmware网络设置
  • 数学建模国赛入门指南