当前位置: 首页 > news >正文

DeepSeek-R1知识蒸馏和微调实践(一)源码

蒸馏和微调实践见上一篇文章:https://blog.csdn.net/wy746801669wy/article/details/149118211?spm=1001.2014.3001.5501

核心源代码r1_distill.py如下,其它源码文件见附件:

# coding=gbk
from openai import OpenAI
from modelscope.msdatasets import MsDataset
import threading
import time 
import json 
import os, sys, csv
import argparse
import subprocess as sp
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplitAPI_KEY=os.getenv("API_KEY")
BASE_URL='https://dashscope.aliyuncs.com/compatible-mode/v1'
MODEL_NAME='deepseek-r1'PROMPT='''
# 角色
我是爸爸,会耐心解答女儿提出的问题。# 注意事项
- 女儿目前读小学一年级,解答时请考虑她的理解水平。
- 如果女儿提出的问题太难,你可以适当简化问题,让她能够理解。
- 请不要使用专业术语和概念,尽量用通俗易懂的语言解答。# 爸爸的风格
- 喜欢循序渐进的讲解,逐步引导女儿理解问题
- 喜欢用生活中的例子来解释抽象的概念
- 理性思维,喜欢用逻辑推理的方式解答问题
- 经常会叫女儿的名字“赛西”,以便保证她的注意力
- 反复确认女儿有没有听懂,通过提问和重复解答的方式,确保她理解了问题
- 偶尔抛出反问或有趣的问题,引发女儿思考# 来自女儿的提问
{question}
'''PROMPT_PURIFICATION='''
# 角色
我是老师,会言简意赅地批阅试卷。# 注意事项
- \n\n\n\n\n之前的第一大段中文用一句话总结出来
- \n\n\n\n\n之后的英文标准答案翻译成中文
- 最后比较\n\n\n\n\n前后含义是否一致,用是或否来回答
- 格式按照“第一段总结:xxx。\n第二段翻译:xxx。\n含义是否一致:是/否”  来输出# 老师的风格
- 言简意赅# 来自考场的材料
{question}
'''THREAD=30
SAMPLES=10class R1Generator:def __init__(self,threads,dataset,samples, prompt):self.client=OpenAI(api_key=API_KEY,base_url=BASE_URL)self.idx=0self.threads=threadsself.dataset=datasetself.samples=samplesself.prompt=promptself.mutex=threading.Lock()def generate(self,question):completion=self.client.chat.completions.create(model=MODEL_NAME,messages=[{'role': 'user', 'content': self.prompt.format(question=question)},])return completion.choices[0].message.reasoning_content,completion.choices[0].message.contentdef begin(self):self.idx=0self.progress=0self.result=[None]*self.samplesself.thread_handlers=[]for i in range(self.threads):t=threading.Thread(target=self.thread_main)t.start()self.thread_handlers.append(t)def join(self):while True:with self.mutex:print(f'Progress: {self.progress}/{self.samples}',end='\r')if self.progress>=self.samples:breaktime.sleep(1)for t in self.thread_handlers:t.join()return [res for res in self.result if res is not None]def thread_main(self):while True:with self.mutex:if self.idx>=self.samples:breakcur_idx=self.idxself.idx+=1try:question=self.dataset[cur_idx]['question']reasoning,answer=self.generate(question)self.result[cur_idx]=(question,reasoning,answer)except Exception as e:print('cur_idx: ', cur_idx, ' question: ', question)print(e)with self.mutex:self.progress+=1def create_purifying_csv(train_dataset):gsm8k = {}ds_r1_check_data = []csv_file_path = './ds_r1_purifying_data.csv'  # 定义 CSV 文件路径for i in train_dataset:gsm8k[i['question']] = i['answer']with open(r'r1_distill.txt', 'r', encoding='utf-8') as f2:for line in f2:line = json.loads(line)ds_r1_check_data.append({'question': line['answer'] + '\n\n\n\n\n' + gsm8k[line['question']]})# 写入 CSV 文件with open(csv_file_path, 'w', newline='', encoding='utf-8') as csv_file:# 获取字段名(假设所有字典有相同的键)fieldnames = ds_r1_check_data[0].keys() if ds_r1_check_data else []# 创建 CSV 写入器writer = csv.DictWriter(csv_file, fieldnames=fieldnames)# 写入表头writer.writeheader()# 写入数据行for row in ds_r1_check_data:writer.writerow(row)print(f'数据已成功写入 {csv_file_path}')def parse_args():parser = argparse.ArgumentParser(description='Knowledge Distillation Processor')# 添加命令行参数parser.add_argument('--ms_dataset_id', type=str, default='modelscope/gsm8k', required=False,help='Dataset ID')parser.add_argument('--purifying', action='store_true', help='Enable purifying mode')# 解析命令行参数args = parser.parse_args()return argsif __name__=='__main__':args = parse_args()dataset_id = args.ms_dataset_idpurifying = args.purifyingif purifying:sampled_dataset = MsDataset.load(dataset_id, subset_name='main', split='train')create_purifying_csv(sampled_dataset)csv_file = MsDataset.load('./ds_r1_purifying_data.csv')r1 = R1Generator(threads=THREAD, dataset=csv_file, samples=SAMPLES, prompt=PROMPT_PURIFICATION)r1.begin()result = r1.join()with open('r1_distill_purifying.txt', 'w') as f:for res in result:question, reasoning, answer = resf.write(json.dumps({'question': question, 'reasoning': reasoning, 'answer': answer}) + '\n')sp.run(['bash', './purification_processor.sh'])else:sampled_dataset = MsDataset.load(dataset_id, subset_name='main', split='train')r1 = R1Generator(threads=THREAD, dataset=sampled_dataset, samples=SAMPLES, prompt=PROMPT)r1.begin()result=r1.join()with open('r1_distill.txt','w') as f:for res in result:question,reasoning,answer=resf.write(json.dumps({'question':question,'reasoning':reasoning,'answer':answer})+'\n')sp.run(['python', 'transfer2openmind.py', 'r1_distill.txt'])

其它源码文件链接如下:
https://pan.quark.cn/s/b1267c853a39

http://www.lryc.cn/news/580153.html

相关文章:

  • 使用 C# 发送电子邮件(支持普通文本、HTML 和附件)
  • BEVFormer模型处理流程
  • 佰力博科技与您探讨表面电阻的测试方法及应用领域
  • Java程序员短时间内如何精通Redis?
  • 基于大模型的强直性脊柱炎全周期预测与诊疗方案研究
  • Spring Boot + 本地部署大模型实现:安全性与可靠性保障
  • 基于Linux的Spark本地模式环境搭建实验指南
  • RabbitMQ 4.1.1初体验
  • Ubuntu Linux Cursor 安装与使用一
  • Web前端数据可视化:ECharts高效数据展示完全指南
  • 【C#】入门
  • Linux三剑客:grep、sed、awk 详解以及find区别
  • 大语言模型预训练数据——数据采样方法介绍以GPT3为例
  • 基于Apache MINA SSHD配置及应用
  • CppCon 2018 学习:OOP is dead, long live Data-oriented design
  • ABP VNext + RediSearch:微服务级全文检索
  • PyCharm 安装使用教程
  • Rust异步爬虫实现与优化
  • 全星 QMS:制造业全面质量管理的数字化全能平台
  • 鸿蒙系统(HarmonyOS)应用开发之手势锁屏密码锁(PatternLock)
  • Jenkins-Publish HTML reports插件
  • 接口测试之postman
  • ZigBee通信技术全解析:从协议栈到底层实现,全方位解读物联网核心无线技术
  • 区块链技术核心组件及应用架构的全面解析
  • 7.4_面试_JAVA_
  • 【PyTorch】PyTorch预训练模型缓存位置迁移,也可拓展应用于其他文件的迁移
  • 基于PHP+MySQL实现(Web)英语学习与测试平台
  • 408第三季part2 - 计算机网络 - 计算机网络基本概念
  • 金融平衡术:创新与合规的突围之路
  • Spark从入门到实战:安装与使用全攻略