DeepSeek-R1知识蒸馏和微调实践(一)源码
蒸馏和微调实践见上一篇文章:https://blog.csdn.net/wy746801669wy/article/details/149118211?spm=1001.2014.3001.5501
核心源代码r1_distill.py如下,其它源码文件见附件:
# coding=gbk
from openai import OpenAI
from modelscope.msdatasets import MsDataset
import threading
import time
import json
import os, sys, csv
import argparse
import subprocess as sp
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# --- Service configuration -------------------------------------------------
# The API key is read from the environment so that it never lands in the
# source tree; os.getenv returns None when the variable is unset.
API_KEY=os.getenv("API_KEY")
# OpenAI-compatible endpoint that the OpenAI client is pointed at.
BASE_URL='https://dashscope.aliyuncs.com/compatible-mode/v1'
# Teacher model whose reasoning and answers are collected.
MODEL_NAME='deepseek-r1'

# --- Prompt templates ------------------------------------------------------
# Both templates are filled with str.format(question=...), so `{question}`
# is the only placeholder they may contain. Each Markdown heading sits on
# its own line, separated from the previous section by a blank line.

# Distillation prompt: the model answers a question in the voice of a father
# explaining things to his first-grade daughter.
PROMPT='''
# 角色
我是爸爸,会耐心解答女儿提出的问题。

# 注意事项
- 女儿目前读小学一年级,解答时请考虑她的理解水平。
- 如果女儿提出的问题太难,你可以适当简化问题,让她能够理解。
- 请不要使用专业术语和概念,尽量用通俗易懂的语言解答。

# 爸爸的风格
- 喜欢循序渐进的讲解,逐步引导女儿理解问题
- 喜欢用生活中的例子来解释抽象的概念
- 理性思维,喜欢用逻辑推理的方式解答问题
- 经常会叫女儿的名字“赛西”,以便保证她的注意力
- 反复确认女儿有没有听懂,通过提问和重复解答的方式,确保她理解了问题
- 偶尔抛出反问或有趣的问题,引发女儿思考

# 来自女儿的提问
{question}
'''

# Purification prompt: the model summarises the text before the separator,
# translates the reference answer after it, and states whether the two agree.
# NOTE: this is not a raw string, so every `\n` below is an escape sequence
# and becomes a real newline in the text sent to the model.
PROMPT_PURIFICATION='''
# 角色
我是老师,会言简意赅地批阅试卷。

# 注意事项
- \n\n\n\n\n之前的第一大段中文用一句话总结出来
- \n\n\n\n\n之后的英文标准答案翻译成中文
- 最后比较\n\n\n\n\n前后含义是否一致,用是或否来回答
- 格式按照“第一段总结:xxx。\n第二段翻译:xxx。\n含义是否一致:是/否” 来输出

# 老师的风格
- 言简意赅

# 来自考场的材料
{question}
'''

# Number of worker threads issuing requests concurrently.
THREAD=30
# Number of dataset rows (taken from index 0 upwards) sent to the model.
SAMPLES=10


class R1Generator:
    """Fan a prompt template out over dataset rows using worker threads.

    Typical use is ``begin()`` followed by ``join()``; ``join()`` returns the
    collected ``(question, reasoning, answer)`` tuples.
    """

    def __init__(self, threads, dataset, samples, prompt):
        """Store the run configuration and create the API client.

        Args:
            threads: number of worker threads started by ``begin()``.
            dataset: indexable collection; ``dataset[i]['question']`` is read.
            samples: how many indices (0 .. samples-1) are processed.
            prompt: template formatted with ``question=...`` for each row.
        """
        self.client = OpenAI(api_key=API_KEY, base_url=BASE_URL)
        self.idx = 0
        self.threads = threads
        self.dataset = dataset
        self.samples = samples
        self.prompt = prompt
        # Guards idx and progress, which every worker reads and updates.
        self.mutex = threading.Lock()

    def generate(self, question):
        """Send one formatted prompt and return ``(reasoning, answer)``."""
        completion = self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {'role': 'user', 'content': self.prompt.format(question=question)},
            ]
        )
        message = completion.choices[0].message
        return message.reasoning_content, message.content

    def begin(self):
        """Reset the counters and start the worker threads."""
        self.idx = 0
        self.progress = 0
        # One slot per sample; slots of failed samples stay None.
        self.result = [None] * self.samples
        self.thread_handlers = []
        for _ in range(self.threads):
            worker = threading.Thread(target=self.thread_main)
            worker.start()
            self.thread_handlers.append(worker)

    def join(self):
        """Wait for the workers, printing progress, and return the results.

        Returns:
            The non-None entries of ``self.result`` in index order.
        """
        while True:
            with self.mutex:
                print(f'Progress: {self.progress}/{self.samples}', end='\r')
                finished = self.progress >= self.samples
            # Also stop when no worker is left alive (e.g. threads == 0):
            # progress can no longer change, so polling would never end.
            if finished or not any(t.is_alive() for t in self.thread_handlers):
                break
            time.sleep(1)
        for worker in self.thread_handlers:
            worker.join()
        # Terminate the '\r' progress line so later output starts cleanly.
        print()
        return [res for res in self.result if res is not None]

    def thread_main(self):
        """Worker loop: claim the next index, query the model, store result."""
        while True:
            with self.mutex:
                if self.idx >= self.samples:
                    break
                cur_idx = self.idx
                self.idx += 1
            # Bound before the try so the error report below never hits an
            # unbound (or stale) name when the dataset lookup itself fails.
            question = None
            try:
                question = self.dataset[cur_idx]['question']
                reasoning, answer = self.generate(question)
                self.result[cur_idx] = (question, reasoning, answer)
            except Exception as e:
                # Best effort: report the failing sample and keep going.
                print('cur_idx: ', cur_idx, ' question: ', question)
                print(e)
            finally:
                # Count the sample whether it succeeded or not; join()
                # relies on progress reaching samples.
                with self.mutex:
                    self.progress += 1


def create_purifying_csv(train_dataset):
    """Build ./ds_r1_purifying_data.csv from r1_distill.txt.

    Each CSV row has a single ``question`` column holding the distilled
    answer, five newlines, and the reference answer that ``train_dataset``
    gives for the same question.

    Args:
        train_dataset: iterable of mappings with 'question' and 'answer'.

    Raises:
        KeyError: a question in r1_distill.txt is absent from train_dataset.
    """
    csv_file_path = './ds_r1_purifying_data.csv'  # output CSV path
    reference = {row['question']: row['answer'] for row in train_dataset}

    ds_r1_check_data = []
    with open(r'r1_distill.txt', 'r', encoding='utf-8') as f2:
        for raw in f2:
            if not raw.strip():
                # Tolerate blank lines (e.g. a trailing newline).
                continue
            record = json.loads(raw)
            ds_r1_check_data.append(
                {'question': record['answer'] + '\n\n\n\n\n' + reference[record['question']]}
            )

    # Write the CSV. The only column is always 'question', so the header is
    # valid even when there are no rows.
    with open(csv_file_path, 'w', newline='', encoding='utf-8') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=['question'])
        writer.writeheader()
        writer.writerows(ds_r1_check_data)
    print(f'数据已成功写入 {csv_file_path}')


def parse_args():
    """Parse the command line: --ms_dataset_id and the --purifying flag."""
    parser = argparse.ArgumentParser(description='Knowledge Distillation Processor')
    parser.add_argument('--ms_dataset_id', type=str, default='modelscope/gsm8k', required=False,
                        help='Dataset ID')
    parser.add_argument('--purifying', action='store_true', help='Enable purifying mode')
    return parser.parse_args()


def _write_results(path, results):
    """Write (question, reasoning, answer) tuples to path as JSON lines."""
    # utf-8 matches the encoding create_purifying_csv uses to read the file.
    with open(path, 'w', encoding='utf-8') as f:
        for question, reasoning, answer in results:
            f.write(json.dumps({'question': question, 'reasoning': reasoning, 'answer': answer}) + '\n')


def main():
    """Run either the distillation pass or the purification pass."""
    args = parse_args()
    train_dataset = MsDataset.load(args.ms_dataset_id, subset_name='main', split='train')

    if args.purifying:
        create_purifying_csv(train_dataset)
        csv_dataset = MsDataset.load('./ds_r1_purifying_data.csv')
        r1 = R1Generator(threads=THREAD, dataset=csv_dataset, samples=SAMPLES, prompt=PROMPT_PURIFICATION)
        r1.begin()
        _write_results('r1_distill_purifying.txt', r1.join())
        sp.run(['bash', './purification_processor.sh'])
    else:
        r1 = R1Generator(threads=THREAD, dataset=train_dataset, samples=SAMPLES, prompt=PROMPT)
        r1.begin()
        _write_results('r1_distill.txt', r1.join())
        # Use the running interpreter so the child sees the same environment.
        sp.run([sys.executable or 'python', 'transfer2openmind.py', 'r1_distill.txt'])


if __name__=='__main__':
    main()
其它源码文件链接如下:
https://pan.quark.cn/s/b1267c853a39