当前位置: 首页 > news >正文

DeepSeek-V2-Chat多卡推理(不考虑性能)

@TOC

本文演示了如何使用accelerate推理DeepSeek-V2-Chat(裁剪以后的模型,仅演示如何将权值拆到多卡)

代码

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from accelerate import init_empty_weights
import sys
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from torch.cuda.amp import autocast
from torch.utils._python_dispatch import TorchDispatchMode
from dataclasses import dataclass
from typing import Any
import torch.cuda
import multiprocessing as mp@dataclass
class _ProfilerState:cls: Anyobject: Any = Noneclass TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parentself.op_index=0self.cvt_count=0def get_max_gpu_id(self,tensors):max_gpu_id = -1max_index = -1tensor_index=[]for i, tensor in enumerate(tensors):if not isinstance(tensor, torch.Tensor):continuetensor_index.append(i)if tensor.is_cuda:gpu_id = tensor.get_device()if gpu_id > max_gpu_id:max_gpu_id = gpu_idmax_index = iif max_gpu_id == -1:return None, None,tensor_indexreturn max_index, max_gpu_id,tensor_indexdef convert(self,op_type,tensor_list):index, gpu_id,tensor_index = self.get_max_gpu_id(tensor_list)if index is None:returnkeep_index=set(tensor_index)-set([index])device=torch.device(f"cuda:{gpu_id}")for i in keep_index:if tensor_list[i].device!=device:#print(f"{op_type} {i} {tensor_list[i].device} -> {device}")tensor_list[i].data=tensor_list[i].data.to(device,non_blocking=True)#卡间通信是串行的,所有多stream并不能充分提升性能def __torch_dispatch__(self, func, types, args=(),kwargs=None):func_packet = func._overloadpacketif kwargs is None:kwargs = {}op_type=f"{func}"self.op_index+=1if isinstance(args, list) or isinstance(args, tuple):self.convert(op_type,args)elif isinstance(args[0], list) or isinstance(args[0], tuple):self.convert(op_type,args[0])else:print(op_type)output= func(*args,**kwargs)return outputclass TorchDumper:def __init__(self,**kwargs):self.p= _ProfilerState(TorchDumpDispatchMode)self.kwargs=kwargsdef __enter__(self):if self.p.object is None:o = self.p.cls(self,**self.kwargs)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)del self.p.objectmodel_name = "./models/deepseek-ai/DeepSeek-V2-Chat/"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
max_memory = {i: "23GB" for i in range(8)}
sys.path.insert(0,model_name)model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,attn_implementation="eager",torch_dtype=torch.bfloat16)
model=model.eval()no_split_module_classes = ['DeepseekV2MLP','DeepseekV2Attention']
#no_split_module_classes = ['DeepseekV2DecoderLayer']device_map = infer_auto_device_map(model,max_memory=max_memory,no_split_module_classes=no_split_module_classes,dtype='float16')model = dispatch_model(model, device_map=device_map)
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_idmessages = [{"role": "user", "content": "Write a piece of quicksort code in C++"} ]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
with TorchDumper():outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
print(result)
http://www.lryc.cn/news/370654.html

相关文章:

  • 算法题day42(补5.28日卡:动态规划02)
  • 分治与递归
  • Spring中IOC容器
  • php redis分布式锁
  • kotlin 中的布尔
  • 有哪些ai聊天推荐?简单分享三款
  • Python第二语言(十、Python面向对象(上))
  • SolidWorks 2016 SP5安装教程
  • 为什么高考志愿只选计算机专业?
  • GPT大模型微调-提高垂直领域回答质量
  • 全网首发-Docker被封后的代理设置教程
  • 代码随想录算法训练营第五十七天|1143.最长公共子序列、1035.不相交的线、53. 最大子序和、392.判断子序列
  • RocketMQ事务性消息
  • mysql (事物)
  • kotlin 中的字符串
  • 网站线上模板建设的优缺点
  • 哲学家进餐问题
  • 无人机遥感在农林信息提取中的实现方法与GIS融合应用
  • 联想测开一面(电话面试)笔试60%
  • 【python】tkinter GUI开发: Button和Entry的应用实战探索
  • sm2证书生成(openssl3.0)
  • java计算年化利率
  • 深入理解ChatGPT工作原理
  • 在 Wed 中应用 MyBatis(同时使用MVC架构模式,以及ThreadLocal 事务控制)
  • Elasticsearch index 设置 false,为什么还可以被检索到?
  • 169. 多数元素
  • ADS基础教程19 - 电磁仿真(EM)基本概念和实操
  • LabVIEW RT环境中因字符串拼接导致的系统崩溃问题
  • 深层网络:层数多真的更好吗?
  • 【QT5】<知识点> QT常用知识(更新中)