当前位置: 首页 > news >正文

计算语言模型计算每秒钟生成的token数量it/s

main() 函数的stream循环中,我们可以计算每秒钟生成的token数量,然后输出 it/s。在流式生成过程中,我们可以使用Python的time模块来计算速度。在测试时,生成速度会受到多个因素的影响,包括设备性能、模型大小、输入文本长度等。

import os
import torch
import platform
from colorama import Fore, Style
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
import timedef init_model():print("init model ...")model = AutoModelForCausalLM.from_pretrained("baichuan-inc/Baichuan-13B-Chat",torch_dtype=torch.float16,device_map="cuda",trust_remote_code=True)model.generation_config = GenerationConfig.from_pretrained("baichuan-inc/Baichuan-13B-Chat")tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-13B-Chat",use_fast=False,trust_remote_code=True)return model, tokenizerdef clear_screen():if platform.system() == "Windows":os.system("cls")else:os.system("clear")print(Fore.YELLOW + Style.BRIGHT + "欢迎使用百川大模型,输入进行对话,clear 清空历史,CTRL+C 中断生成,stream 开关流式生成,exit 结束。")return []def main(stream=True):model, tokenizer = init_model()messages = clear_screen()while True:prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)if prompt.strip() == "exit":breakif prompt.strip() == "clear":messages = clear_screen()continueprint(Fore.CYAN + Style.BRIGHT + "\nBaichuan:" + Style.NORMAL, end='')if prompt.strip() == "stream":stream = not streamprint(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')continuemessages.append({"role": "user", "content": prompt})if stream:position = 0try:start_time = time.time()total_tokens = 0for response in model.chat(tokenizer, messages, stream=True):print(response[position:], end='', flush=True)position = len(response)total_tokens += len(tokenizer(response, return_tensors='pt')['input_ids'][0])if torch.backends.mps.is_available():torch.mps.empty_cache()end_time = time.time()elapsed_time = end_time - start_timetokens_per_second = total_tokens / elapsed_timeprint(f"\n\n生成速度:{tokens_per_second:.2f} tokens/s")except KeyboardInterrupt:passprint()else:response = model.chat(tokenizer, messages)print(response)if torch.backends.mps.is_available():torch.mps.empty_cache()messages.append({"role": "assistant", "content": response})print(Style.RESET_ALL)if __name__ == "__main__":main()

http://www.lryc.cn/news/108040.html

相关文章:

  • Clickhouse调研
  • 02.Redis实现添加缓存功能
  • 【1.2】Java微服务:SpringCloud概论
  • 右键文件夹 ------- 打开 vscode的方法
  • 小程序原生实现左右锚点联动
  • STM32 低功耗-睡眠模式
  • IDEA用Gradle构建项目时,lombok插件无效的解决办法
  • 基于方向编码的模板匹配算法matlab仿真
  • shell centos 7 一键部署 KVM软件脚本
  • 64 # 实现一个 http-server
  • HCIP作业3
  • 【测试学习三】软件测试的生命周期 BUG的相关知识
  • git rebase 的坑儿
  • SSM(Vue3+ElementPlus+Axios+SSM前后端分离)【四】
  • iPhone 8 Plus透明屏应用范围详解
  • 【前端面试手撕题】instanceof、Array.map、Array.filter、Array.reduce、_objectCreate
  • 8.物联网操作系统之事件标志组
  • [腾讯云Cloud Studio实战训练营]无门槛使用GPT+Cloud Studio辅助编程完成Excel自动工资结算
  • 局域网ssh登录windows自带Linux系统(WSL)踩坑记录
  • 2023-02-03——2023-08-03,半年以来与客服交流的记录【CSND 文章撰写 网站使用求解】客服咨询交流记录(长期更新ing)
  • DCL 操作
  • C++11移动构造函数详解
  • 【力扣】19. 删除链表的倒数第 N 个结点 <链表指针、快慢指针>
  • Vue3和typeScript路由传参
  • SQLserver数据库巡检脚本
  • Go Ethereum源码学习笔记 001 Geth Start
  • idea如何加快创建Maven项目的速度
  • 软件外包开发的GO开发框架
  • oracle会话打满
  • VSCode自定义闪烁光标