当前位置: 首页 > news >正文

基于xiaothink对Wanyv-50M模型进行c-eval评估

使用 pip 从 PyPI 安装 xiaothink:

pip install xiaothink==1.0.2

下载模型:
万语-50M


开始评估(修改模型路径后即可直接开始运行,结果保存在output文件夹里):

import json
import os
import random
import re
import time
from collections import Counter

import pandas as pd
import requests
from tqdm import tqdm

from xiaothink.llm.inference.test_formal import *
# Answer letters for the multiple-choice C-Eval questions.
# NOTE: the original script passed an undefined name `choices` to
# Llama_Evaluator (NameError); it is defined explicitly here.
choices = ["A", "B", "C", "D"]

# Load the Wanyv-50M checkpoint once at import time; `pre` below depends on it.
model = QianyanModel(
    MT=40.231,
    ckpt_dir=r'path\to\wanyv\model\ckpt_test_40_2_3_1_formal_open',
)


def chat_x(inp, temp=0.3):
    """Single-turn chat with the loaded model, stopping at the first '。'."""
    return model.chat_SingleTurn(inp, temp=temp, loop=True, stop='。')


def pre(question: str, options_str: str) -> str:
    """Ask the model one multiple-choice question and return a letter 'A'-'D'.

    Strips embedded '答案:' markers, queries the model (majority vote over the
    collected answers; currently a single query), appends five prompt/answer
    variants to 'ceval_text_sklm.txt' for later fine-tuning data, and returns
    the chosen option letter.
    """
    question = question.replace('答案:', '')
    options_str = options_str.replace('答案:', '')
    if not 'A' in question:
        # Options are not already embedded in the question text, so the
        # templates must carry them explicitly.
        prompt_template = '''题目:{question}\n{options_str}\n让我们首先一步步思考,最后在回答末尾给出一个字母作为你的答案(A或B或C或D)'''
        prompt_template2 = '''题目:{question}\n选项:{options_str}\n给出答案'''
        prompt_template3 = '''{question}\n{options_str}\n'''
        prompt_template4 = '''{question}\n{options_str}\n给出你的选择'''
        prompt_template5 = '''题目:{question}\n{options_str}\n答案:'''
    else:
        prompt_template = '''题目:{question}\n让我们首先一步步思考,最后在回答末尾给出一个字母作为你的答案(A或B或C或D)'''
        prompt_template2 = '''题目:{question}\n给出答案'''
        prompt_template3 = '''{question}\n'''
        prompt_template4 = '''{question}\n给出你的选择'''
        prompt_template5 = '''题目:{question}\n答案:'''

    ansd = {}   # option letter -> raw model response that produced it
    answers = []
    # Designed for multiple sampled runs + majority vote; currently one run.
    for _ in range(1):
        response = chat_x(prompt_template.format(question=question, options_str=options_str))
        # Take the first option letter that appears anywhere in the response.
        for option in 'ABCD':
            if option in response:
                answers.append(option)
                ansd[option] = response
                break
        else:
            print('AI选项检查:', repr(response))
            answers.append('A')  # Default to 'A' if no option found
            ansd['A'] = ''

    # Majority vote; ties resolve to the alphabetically first option.
    answer_counts = Counter(answers)
    most_common_answers = answer_counts.most_common()
    highest_frequency = most_common_answers[0][1]
    most_frequent_answers = [answer for answer, count in most_common_answers
                             if count == highest_frequency]
    final_answer = min(most_frequent_answers)

    # Log one instruction-tuning record per template variant.  The record is
    # built by naive string replacement, so quotes/newlines in the model output
    # may yield non-strict JSON lines (kept for compatibility with the
    # original output format).
    record = '{"instruction": "{prompt_template}", "input": "", "output": "{final_answer}"}\n'
    with open('ceval_text_sklm.txt', 'a', encoding='utf-8') as f:
        for template in (prompt_template, prompt_template2, prompt_template3,
                         prompt_template4, prompt_template5):
            f.write(
                record
                .replace('{prompt_template}',
                         template.format(question=question, options_str=options_str)
                         .replace('\n', '\\n'))
                .replace('{final_answer}', ansd[final_answer])
            )
    return final_answer


class Llama_Evaluator:
    """Runs C-Eval style multiple-choice evaluation over pandas DataFrames."""

    def __init__(self, choices, k):
        # choices: option letters (e.g. ["A","B","C","D"]); k: few-shot count
        # (-1 means "use the whole dev set").
        self.choices = choices
        self.k = k

    def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False,
                     cot=False, save_result_dir=None, with_prompt=False,
                     constrained_decoding=False, do_test=False):
        """Evaluate one subject; return (accuracy percentage, {row_index: answer}).

        When do_test is True the ground-truth column is absent and every
        answer is scored against the placeholder 'NA'.
        """
        all_answers = {}
        correct_num = 0
        if save_result_dir:
            result = []
            score = []
        if few_shot:
            history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
        else:
            history = ''
        answers = ['NA'] * len(test_df) if do_test is True else list(test_df['answer'])
        for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot,
                                           with_prompt=with_prompt)
            options_str = self.format_options(row)
            instruction = history + question + "\n选项:" + options_str
            ans = pre(instruction, options_str)
            if ans == answers[row_index]:
                correct_num += 1
                correct = 1
            else:
                correct = 0
            print(f"\n=======begin {str(row_index)}=======")
            print("question: ", question)
            print("options: ", options_str)
            print("ans: ", ans)
            print("ground truth: ", answers[row_index], "\n")
            if save_result_dir:
                result.append(ans)
                score.append(correct)
            print(f"=======end {str(row_index)}=======")
            all_answers[str(row_index)] = ans
        correct_ratio = 100 * correct_num / len(answers)
        if save_result_dir:
            test_df['model_output'] = result
            test_df['correctness'] = score
            test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv'))
        return correct_ratio, all_answers

    def format_example(self, line, include_answer=True, cot=False, with_prompt=False):
        """Render one DataFrame row as a question (optionally with its answer)."""
        example = line['question']
        for choice in self.choices:
            example += f'\n{choice}. {line[f"{choice}"]}'
        if include_answer:
            if cot:
                example += "\n答案:让我们一步一步思考,\n" + \
                    line["explanation"] + f"\n所以答案是{line['answer']}。\n\n"
            else:
                example += '\n答案:' + line["answer"] + '\n\n'
        else:
            if with_prompt is False:
                if cot:
                    example += "\n答案:让我们一步一步思考,\n1."
                else:
                    example += '\n答案:'
            else:
                if cot:
                    example += "\n答案是什么?让我们一步一步思考,\n1."
                else:
                    example += '\n答案是什么? '
        return example

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        """Build the few-shot preamble from the first k rows of the dev set."""
        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        for i in range(k):
            prompt += self.format_example(dev_df.iloc[i, :],
                                          include_answer=True,
                                          cot=cot)
        return prompt

    def format_options(self, line):
        """Render the option columns of a row as 'A: ... B: ... ' text."""
        options_str = ""
        for choice in self.choices:
            options_str += f"{choice}: {line[f'{choice}']} "
        return options_str


def main(model_path, output_dir, take, few_shot=False, cot=False,
         with_prompt=False, constrained_decoding=False, do_test=False,
         n_times=1, do_save_csv=False):
    """Evaluate every subject under data/val (or data/test) and write
    submission.json / summary.json under <output_dir>/take<take>."""
    assert os.path.exists("subject_mapping.json"), "subject_mapping.json not found!"
    with open("subject_mapping.json") as f:
        subject_mapping = json.load(f)
    filenames = os.listdir("data/val")
    subject_list = [val_file.replace("_val.csv", "") for val_file in filenames]
    accuracy, summary = {}, {}
    run_date = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
    save_result_dir = os.path.join(output_dir, f"take{take}")
    if not os.path.exists(save_result_dir):
        os.makedirs(save_result_dir, exist_ok=True)
    evaluator = Llama_Evaluator(choices=choices, k=n_times)
    all_answers = {}
    for index, subject_name in tqdm(list(enumerate(subject_list)), desc='主进度'):
        print(f"{index / len(subject_list)} Inference starts at {run_date} on {model_path} with subject of {subject_name}!")
        val_file_path = os.path.join('data/val', f'{subject_name}_val.csv')
        dev_file_path = os.path.join('data/dev', f'{subject_name}_dev.csv')
        test_file_path = os.path.join('data/test', f'{subject_name}_test.csv')
        val_df = pd.read_csv(val_file_path) if not do_test else pd.read_csv(test_file_path)
        dev_df = pd.read_csv(dev_file_path) if few_shot else None
        correct_ratio, answers = evaluator.eval_subject(
            subject_name, val_df, dev_df,
            save_result_dir=save_result_dir if do_save_csv else None,
            few_shot=few_shot,
            cot=cot,
            with_prompt=with_prompt,
            constrained_decoding=constrained_decoding,
            do_test=do_test)
        print(f"Subject: {subject_name}")
        print(f"Acc: {correct_ratio}")
        accuracy[subject_name] = correct_ratio
        summary[subject_name] = {
            "score": correct_ratio,
            "num": len(val_df),
            "correct": correct_ratio * len(val_df) / 100,
        }
        all_answers[subject_name] = answers

    # Use context managers so the result files are flushed and closed
    # (the original passed bare open() handles to json.dump and leaked them).
    with open(save_result_dir + '/submission.json', 'w') as f:
        json.dump(all_answers, f, ensure_ascii=False, indent=4)
    print("Accuracy:")
    for k, v in accuracy.items():
        print(k, ": ", v)

    total_num = 0
    total_correct = 0
    summary['grouped'] = {
        "STEM": {"correct": 0.0, "num": 0},
        "Social Science": {"correct": 0.0, "num": 0},
        "Humanities": {"correct": 0.0, "num": 0},
        "Other": {"correct": 0.0, "num": 0},
    }
    # NOTE(review): assumes every key of subject_mapping was evaluated above;
    # a subject present in the mapping but missing from data/val would raise
    # KeyError here — confirm the two stay in sync.
    for subj, info in subject_mapping.items():
        group = info[2]
        summary['grouped'][group]["num"] += summary[subj]['num']
        summary['grouped'][group]["correct"] += summary[subj]['correct']
    for group, info in summary['grouped'].items():
        info['score'] = info["correct"] / info["num"]
        total_num += info["num"]
        total_correct += info["correct"]
    summary['All'] = {"score": total_correct / total_num,
                      "num": total_num,
                      "correct": total_correct}
    with open(save_result_dir + '/summary.json', 'w') as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)


# Example usage
if __name__ == "__main__":
    model_path = "path/to/model"
    output_dir = "output"
    take = 0
    few_shot = False
    cot = False
    with_prompt = False
    constrained_decoding = False
    do_test = True  # False
    n_times = 1
    do_save_csv = False
    main(model_path, output_dir, take, few_shot, cot, with_prompt,
         constrained_decoding, do_test, n_times, do_save_csv)
http://www.lryc.cn/news/508207.html

相关文章:

  • 使用k6进行kafka负载测试
  • Unity A*算法实现+演示
  • 浏览器要求用户确认 Cookies Privacy(隐私相关内容)是基于隐私法规的要求,VUE 实现,html 代码
  • 如何设计高效的商品系统并提升扩展性:从架构到实践的全方位探索
  • 使用计算机创建一个虚拟世界
  • datasets笔记:两种数据集对象
  • 【ETCD】【Linearizable Read OR Serializable Read】ETCD 数据读取:强一致性 vs 高性能,选择最适合的读取模式
  • 【CSS in Depth 2 精译_089】15.2:CSS 过渡特效中的定时函数
  • 不常用命令指南
  • spring mvc | servlet :serviceImpl无法自动装配 UserMapper
  • STM32 HAL库之串口接收不定长字符
  • Pyqt6的tableWidget填充数据
  • ASP.NET Core - 依赖注入 自动批量注入
  • UVM 验证方法学之interface学习系列文章(十一)virtual interface 再续篇
  • 面试题整理5----进程、线程、协程区别及僵尸进程处理
  • OpenTK 中帧缓存的深度解析与应用实践
  • 第2节-Test Case如何调用Object Repository中的请求并关联参数
  • 【HarmonyOS NEXT】Web 组件的基础用法以及 H5 侧与原生侧的双向数据通讯
  • Android学习(六)-Kotlin编程语言-数据类与单例类
  • CV-OCR经典论文解读|An Empirical Study of Scaling Law for OCR/OCR 缩放定律的实证研究
  • 力扣274. H 指数
  • 挑战一个月基本掌握C++(第五天)了解运算符,循环,判断
  • Python的sklearn中的RandomForestRegressor使用详解
  • ReactPress 1.6.0:重塑博客体验,引领内容创新
  • 人脸生成3d模型 Era3D
  • kubeadm搭建k8s集群
  • centOS系统进程管理基础知识
  • STM32中ADC模数转换器
  • 初学stm32 --- 外部中断
  • wordpress调用指定分类ID下 相同标签的内容