当前位置: 首页 > news >正文

大模型蒸馏(distillation)---从DeepseekR1-1.5B到Qwen-2.5-1.5B蒸馏

 

目录

1.1 蒸馏目标

2 环境准备

2.1依赖库安装

2.2 硬件要求

2.3 模型与数据集下载

2.3.1 教师模型下载

2.3.2 学生模型下载

 2.3.3 数据集准备或下载

 3.过程日志

 4. 模型加载与配置

4.1 加载教师模型

4.2 加载学生模型

4.3 数据预处理函数  

 4.4 数据收集器

4.5 定义训练参数

4.6 定义蒸馏配置

4.7 定义训练配置

4.8 创建蒸馏器 

4.9 开始蒸馏 

 5. 完整代码

6.结合上述内容和TextBrewer,自己重新整理了一遍代码,仅供参考:

1.1 蒸馏目标

将 DeepSeek 的推理能力迁移到 Qwen-2.5;

确保学生模型与 Qwen-2.5 的原始功能(如对话、多语言支持)兼容。

2 环境准备

2.1依赖库安装

pip install torch torchvision transformers datasets2.2
pip install accelerate # 加速分布式训练
pip install evaluate # 评估指标

2.2 硬件要求

GPU:建议使用单张或多张 NVIDIA GPU(如 V100、A100)建议至少 24GB。

CUDA:安装与 PyTorch 兼容的 CUDA 版本。

2.3 模型与数据集下载

2.3.1 教师模型下载

Qwen-2.5-1.5B从huggingface 下载,离线下载方式(从https://hf-mirror.com离线下载):

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir ./models/qwen2.5-1.5B --local-dir-use-symlinks False

2.3.2 学生模型下载

Qwen-2.5-1.5B

$env:HF_ENDPOINT = "https://hf-mirror.com"huggingface-cli download Qwen/Qwen2.5-1.5B --local-dir ./models/qwen2.5-1.5B --local-dir-use-symlinks False

 2.3.3 数据集准备或下载

建议使用大规模文本数据集(如 wikitex、Wikipedia、BooksCorpus、OpenWebText 等)。离线下载地址(从https://www.kaggle.com/datasets/jayanthbontha/wikitext下载)

 3.过程日志

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")

 4. 模型加载与配置

4.1 加载教师模型

# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)

4.2 加载学生模型

# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)

4.3 数据预处理函数  

dataset.map() 是 Hugging Face datasets 库中用于对数据集进行批量预处理的核心方法。当 batched=True 时,它会将数据集分批(batch)传递给 preprocess_function,而不是逐个样本处理。这种批量处理方式效率更高,尤其适合大规模数据集。

# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)

 4.4 数据收集器

DataCollatorForLanguageModeling 是 Hugging Face transformers 库中的一个数据整理类(Data Collator),用于在训练语言模型(如 BERT、GPT 等)时动态生成训练样本。它可以根据任务需求(如掩码语言模型(MLM)或因果语言模型(CLM))对输入数据进行预处理。

# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)

mlm(关键参数):作用:控制是否启用**掩码语言模型(MLM)**模式。

mlm=True:随机掩码输入中的部分 token(如 BERT 训练方式),生成 [MASK] 标记。

mlm=False:禁用掩码,适用于因果语言模型(CLM)(如 GPT 训练方式),输入和标签为原始 token 序列。

4.5 定义训练参数

# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)

4.6 定义蒸馏配置

# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度hard_label_weight=0.5,  # 真实标签损失权重kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)intermediate_matches=[  # 中间层匹配配置{"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失}])

4.7 定义训练配置

# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择log_dir="./logs",                                     # 日志目录output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态))

4.8 创建蒸馏器 

# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)

4.9 开始蒸馏 

# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)args=training_args,   # 训练参数(如学习率、批次大小、设备等)train_dataset=train_dataset,  # 训练数据集(包含输入和标签)eval_dataset=eval_dataset,    # 验证数据集(用于评估模型性能)data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理)# 开始模型训练trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等logger.info("Training finished")  # 记录训练结束日志

 5. 完整代码

import osimport torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, Trainer, \TrainingArguments
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig
from datasets import load_dataset
import logging# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)# 获取当前脚本文件的绝对路径
current_script_path = os.path.abspath(__file__)
logger.info(f"Current script path: {current_script_path}")# 获取当前脚本文件所在的目录
current_script_dir = os.path.dirname(current_script_path)
logger.info(f"Current script directory: {current_script_dir}")# 加载教师模型(DeepSeek-R1:1.5B)
teacher_model_name = os.path.join(current_script_dir, "../models/DeepSeek-R1-Distill-Qwen-1.5B")
logger.info(f"Loading teacher model: {teacher_model_name}")
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name,local_files_only=True
)
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,local_files_only=True
)# 加载学生模型(Qwen)
student_model_name = os.path.join(current_script_dir, "../models/qwen2.5-1.5B")  # 确保模型名称正确
logger.info(f"Loading student model: {student_model_name}")
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name,local_files_only=True
)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name,local_files_only=True
)# 准备数据集
datasets_name = os.path.join(current_script_dir, "../models/Dataset/wikitext-2-raw/")  # 确保模型名称正确
data_files = {"train": datasets_name+"wiki.train.raw","test": datasets_name+"wiki.test.raw"
}
logger.info(f"Loading dataset from local files: {data_files}")
dataset = load_dataset("text", data_files=data_files)
train_dataset = dataset["train"]
eval_dataset = dataset["test"]# 数据预处理
logger.info(f"Preprocess_function")
def preprocess_function(examples):return teacher_tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)logger.info("Preprocessing train dataset")
train_dataset = train_dataset.map(preprocess_function, batched=True)
logger.info("Preprocessing eval dataset")
eval_dataset = eval_dataset.map(preprocess_function, batched=True)# 数据收集器
logger.info("DataCollatorForLanguageModeling")
data_collator = DataCollatorForLanguageModeling(tokenizer=teacher_tokenizer, mlm=False)# 定义训练参数
logger.info("Creating trainer")
training_args = TrainingArguments(output_dir="./results",            # 训练结果保存路径eval_strategy="epoch",             # 每个 epoch 结束时评估learning_rate=5e-5,                # 学习率(默认 5e-5 是常见选择)per_device_train_batch_size=2,     # 每个设备的训练 batch size(GPU 单卡)per_device_eval_batch_size=2,      # 每个设备的评估 batch sizenum_train_epochs=3,                # 训练轮次(3 轮可能较短,需根据任务调整)weight_decay=0.01,                 # 权重衰减(L2 正则化)logging_dir="./logs",              # 日志保存路径logging_steps=100,                 # 每 100 步记录一次日志fp16=False,                        # 是否启用混合精度训练(建议开启)gradient_accumulation_steps=4,     # 梯度累积步数(等效 batch_size=8)report_to="tensorboard",           # 使用 TensorBoard 记录训练过程# tensorboard_dir="./tensorboard"  # 可选:指定 TensorBoard 日志目录
)# 定义蒸馏配置  weight:添加权重,"loss": "mse"
logger.info("Creating distillation config")
distill_config = DistillationConfig(temperature=2.0,  # 温度参数,控制软标签的平滑程度hard_label_weight=0.5,  # 真实标签损失权重kd_loss_type="ce",      # 知识蒸馏损失类型(交叉熵)intermediate_matches=[  # 中间层匹配配置{"layer_T": 6,    # 教师模型的第6层"layer_S": 6,    # 学生模型的第6层"feature": "hidden",  # 匹配隐藏层特征"weight": 1.0,   # 中间层损失权重"loss": "mse"    # 使用均方误差损失}]
)# 定义训练配置
logger.info("Creating training config")
train_config = TrainingConfig(device="cuda" if torch.cuda.is_available() else "cpu",  # 设备选择log_dir="./logs",                                     # 日志目录output_dir="./outputs"                                # 模型输出目录# save_best_model=True,  # 是否保存最佳模型(注释状态)# save_last_model=True,  # 是否保存最后模型(注释状态)# save_model_every_epoch=True,  # 是否每轮保存模型(注释状态)# tensorboard_dir="./tensorboard"  # TensorBoard 日志目录(注释状态)
)# 创建蒸馏器
logger.info("Creating distiller")
distiller = GeneralDistiller(train_config=train_config,        # 训练配置(包含设备、路径等)distill_config=distill_config,    # 蒸馏配置(温度、损失权重等)model_T=teacher_model,            # 教师模型model_S=student_model,            # 学生模型adaptor_T=None,                   # 教师模型适配器(未配置)adaptor_S=None                    # 学生模型适配器(未配置)
)# 开始蒸馏
with distiller:  # 使用蒸馏器上下文管理器,确保资源正确初始化和释放logger.info("Starting training")  # 记录训练开始日志# 初始化 Trainer,集成模型蒸馏配置trainer = Trainer(model=student_model,  # 学生模型(需要训练的小模型)args=training_args,  # 训练参数(如学习率、批次大小、设备等)train_dataset=train_dataset,  # 训练数据集(包含输入和标签)eval_dataset=eval_dataset,  # 验证数据集(用于评估模型性能)data_collator=data_collator,  # 数据批量处理函数(将单条数据组合成批次)# processing_class=teacher_tokenizer  # 注意:此处可能存在问题(见下方说明)# 正确做法:适配器或数据处理逻辑应在蒸馏配置中处理)# 开始模型训练trainer.train()  # 启动训练循环,包含前向传播、损失计算、反向传播等trainer.save_model()logger.info("Training finished")  # 记录训练结束日志
复制代码

参考:

模型蒸馏(Distillation)案例--从DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏 - InProsperity - 博客园

模型蒸馏(Distillation)案例--从DeepSeek-R1-1.5B 到 Qwen-2.5-1.5B 的模型蒸馏-CSDN博客

6.结合上述内容和TextBrewer,自己重新整理了一遍代码,仅供参考:

import os
import torch
import logging
from transformers import (AutoModelForCausalLM,AutoTokenizer,DataCollatorForLanguageModeling,get_linear_schedule_with_warmup
)
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig
from datasets import load_dataset
from torch.optim import AdamW# 配置日志
logging.basicConfig(level=logging.INFO,format='%(asctime)s - %(levelname)s - %(message)s',handlers=[logging.FileHandler("distillation.log"),logging.StreamHandler()]
)
logger = logging.getLogger(__name__)# 设备设置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")# ======================
# 1. 加载模型和Tokenizer
# ======================
def load_models_and_tokenizers():"""加载教师和学生模型"""# 教师模型 (DeepSeek-R1 1.5B)teacher_model_name = "deepseek-ai/deepseek-r1-1.5b"logger.info(f"Loading teacher model: {teacher_model_name}")teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name,torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)# 学生模型 (Qwen 1.5B)student_model_name = "Qwen/Qwen1.5-1.8B"logger.info(f"Loading student model: {student_model_name}")student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)student_model = AutoModelForCausalLM.from_pretrained(student_model_name,torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)# 对齐tokenizer(关键步骤!)if teacher_tokenizer.vocab != student_tokenizer.vocab:logger.warning("Tokenizers not aligned, adding special tokens...")student_tokenizer.add_special_tokens({'pad_token': '[PAD]'})student_model.resize_token_embeddings(len(student_tokenizer))return teacher_model, student_model, teacher_tokenizer, student_tokenizer# ======================
# 2. 数据准备
# ======================
def prepare_data(student_tokenizer):"""加载并预处理数据"""# 加载数据集(示例使用wikitext)dataset = load_dataset("wikitext", "wikitext-2-raw-v1")# 预处理函数def preprocess_function(examples):return student_tokenizer(examples["text"],truncation=True,padding="max_length",max_length=512,return_tensors="pt")# 处理数据集train_dataset = dataset["train"].map(preprocess_function,batched=True,remove_columns=["text"])eval_dataset = dataset["validation"].map(preprocess_function,batched=True,remove_columns=["text"])# 数据收集器data_collator = DataCollatorForLanguageModeling(tokenizer=student_tokenizer,mlm=False)return train_dataset, eval_dataset, data_collator# ======================
# 3. 蒸馏配置
# ======================
def get_distillation_config():"""配置蒸馏参数"""return DistillationConfig(temperature=2.0,  # 初始温度temperature_scheduler=lambda x: max(0.5, 2.0 - 0.1 * x),  # 温度衰减hard_label_weight=0.3,  # 真实标签权重kd_loss_weight=0.7,  # 蒸馏损失权重kd_loss_type="ce",  # 交叉熵损失intermediate_matches=[{"layer_T": [6, 12, 18],  # 教师模型层"layer_S": [3, 6, 9],  # 学生模型层"feature": "hidden",  # 隐藏状态"loss": "cosine",  # 余弦相似度损失"weight": 0.5,"proj": ["linear", 1536, 1024]  # 维度投影},{"layer_T": [9, 15],"layer_S": [4, 7],"feature": "attention",  # 注意力矩阵"loss": "mse","weight": 0.3}])# ======================
# 4. 训练配置
# ======================
def get_training_config():"""配置训练参数"""return TrainingConfig(output_dir="./distill_output",device=device,fp16=torch.cuda.is_available(),gradient_accumulation_steps=4,ckpt_frequency=500,  # 每500步保存检查点log_steps=100,max_grad_norm=1.0,  # 梯度裁剪save_optimizer=False  # 为节省空间不保存优化器)# ======================
# 5. 优化器设置
# ======================
def get_optimizer(model):"""配置优化器和学习率调度"""optimizer = AdamW(model.parameters(),lr=5e-5,weight_decay=0.01)scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps=500,num_training_steps=3000)return optimizer, scheduler# ======================
# 主函数
# ======================
def main():# 1. 加载模型和数据teacher_model, student_model, teacher_tokenizer, student_tokenizer = load_models_and_tokenizers()train_dataset, eval_dataset, data_collator = prepare_data(student_tokenizer)# 2. 配置蒸馏distill_config = get_distillation_config()train_config = get_training_config()# 3. 初始化蒸馏器distiller = GeneralDistiller(train_config=train_config,distill_config=distill_config,model_T=teacher_model,model_S=student_model,adaptor_T=None,  # 使用默认适配器adaptor_S=None)# 4. 准备优化器optimizer, scheduler = get_optimizer(student_model)# 5. 开始蒸馏logger.info("Starting distillation...")with distiller:distiller.train(optimizer=optimizer,scheduler=scheduler,train_dataset=train_dataset,eval_dataset=eval_dataset,batch_size=2,num_epochs=3,data_collator=data_collator,callback=None)# 6. 保存最终模型student_model.save_pretrained("./final_student_model")student_tokenizer.save_pretrained("./final_student_model")logger.info("Distillation completed!")if __name__ == "__main__":main()

另外,可以了解Text Generation WebUI,集成不同大模型进行推理,测试。

https://github.com/oobabooga/text-generation-webui

http://www.lryc.cn/news/600797.html

相关文章:

  • UniappDay03
  • 【Canvas与旗帜】条纹版大明三辰旗
  • AI是否会终结IT职业?深度剖析IT行业的“涌现”与重构
  • 慧星云新增大模型服务:多款大模型轻松调用
  • C++:STL中vector的使用和模拟实现
  • MySQL的底层原理--InnoDB数据页结构
  • 人大金仓 kingbase 连接数太多, 清理数据库连接数
  • 基于匿名管道的多进程任务池实现与FD泄漏解决方案
  • VUE2 学习笔记7 v-model、过滤器
  • 6.数组和字符串
  • ChatIm项目文件上传与获取
  • 拉普拉斯方程的径向解法
  • opencv学习(图像金字塔)
  • DriverManager在rt.jar里,凭什么能加载到classpath下的驱动?
  • Vue当中背景图无法占满屏幕的解决方法
  • 记一次腾讯云临时密钥接管存储桶
  • 零基础 “入坑” Java--- 十四、【练习】图书小系统
  • mrpc框架项目的AI总结
  • 热传导问题Matlab有限元编程 :工业级热仿真核心技术-搭建热传导求解器【含案例源码】
  • 【ELasticsearch】节点角色分类与作用解析
  • ubuntu下docker安装thingsboard物联网平台详细记录(附每张图)
  • 考研复习-数据结构-第八章-排序
  • 求hom_math_2d的角度值
  • URL与URI:互联网世界的“门牌号“与“身份证“
  • DocC的简单使用
  • ICMP报文工作原理
  • Linux如何执行系统调用及高效执行系统调用:深入浅出的解析
  • Python 数据分析(二):Matplotlib 绘图
  • 斐波那契数列加强版 快速矩阵幂
  • 特产|基于SSM+vue的南阳特产销售平台(源码+数据库+文档)