当前位置: 首页 > news >正文

HuggingFace学习笔记--BitFit高效微调

1--BitFit高效微调

        BitFit,全称是 bias-term fine-tuning,其高效微调只去微调带有 bias 的参数,其余参数全部固定;

2--实例代码

from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer# 分词器
tokenizer = AutoTokenizer.from_pretrained("Langboat/bloom-1b4-zh")# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):MAX_LENGTH = 256input_ids, attention_mask, labels = [], [], []instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")response = tokenizer(example["output"] + tokenizer.eos_token)input_ids = instruction["input_ids"] + response["input_ids"]attention_mask = instruction["attention_mask"] + response["attention_mask"]labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}if __name__ == "__main__":# 加载数据集dataset = load_from_disk("./PEFT/data/alpaca_data_zh")# 处理数据tokenized_ds = dataset.map(process_func, remove_columns = dataset.column_names)# print(tokenizer.decode(tokenized_ds[1]["input_ids"]))# print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))# 创建模型model = AutoModelForCausalLM.from_pretrained("Langboat/bloom-1b4-zh", low_cpu_mem_usage=True)# 基于bitfit只训练带有bias的参数for name, param in model.named_parameters():if "bias" not in name:param.requires_grad = False# 训练参数args = TrainingArguments(output_dir = "./chatbot",per_device_train_batch_size = 1,gradient_accumulation_steps = 8,logging_steps = 10,num_train_epochs = 1)# trainertrainer = Trainer(model = model,args = args,train_dataset = tokenized_ds,data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True))# 训练模型trainer.train()# 模型推理pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "output = pipe(ipt, max_length=256, do_sample=True)print(output)

结果:

http://www.lryc.cn/news/253831.html

相关文章:

  • 阅读笔记|A Survey of Large Language Models
  • JSP 设置静态文件资源访问路径
  • 【Pytorch】Visualization of Feature Maps(4)——Saliency Maps
  • java第三十课
  • Scala--2
  • 【SQL SERVER】定时任务
  • MyBatis-Plus学习笔记(无脑cv即可)
  • 【VUE】watch 监听失效
  • python的异常处理批量执行网络设备的巡检命令
  • react native 环境准备
  • PGSQL(PostgreSQL)数据库安装教程
  • 识别和修复网站上损坏链接的最佳实践
  • 使用Navicat连接MySQL出现的一些错误
  • 4G基站BBU、RRU、核心网设备
  • iphone/安卓手机如何使用burp抓包
  • springboot云HIS医院信息综合管理平台源码
  • 【视觉SLAM十四讲学习笔记】第三讲——四元数
  • Linux系统之部署Plik临时文件上传系统
  • 【EI征稿中#先投稿,先送审#】第三届网络安全、人工智能与数字经济国际学术会议(CSAIDE 2024)
  • 『亚马逊云科技产品测评』活动征文|基于亚马逊云EC2搭建OA系统
  • Mysql更新varchar存储的Josn数据
  • JSON.stringify与JSON.parse详解与实践
  • vue 基础
  • 使用axios下载后端接口返回的文件流格式文件
  • 在macOS上使用Homebrew安装PHP的完整指南
  • 图片处理OpenCV IMDecode模式说明【生产问题处理】
  • 吹响AI技术应用的号角
  • C //例10.1 从键盘输入一些字符,逐个把它们送到磁盘上去,直到用户输入一个“#”为止。
  • ARM预取侧信道(Prefetcher Side Channels)攻击与防御
  • 数据结构 | 二叉树的各种遍历