当前位置: 首页 > news >正文

PaddleNLP使用Vicuna

LLaMA 模型

LLaMa 是一个大型语言模型,由 Meta 开源。它的全称是 Large Language Model Meta AI,参数量从 70 亿到 650 亿不等。例如,130 亿参数的 LLaMA 模型在大多数基准上可以胜过参数量达 1750 亿的 GPT-3,而且可以在单块 V100 GPU 上运行。而最大的 650 亿参数的 LLaMA 模型可以媲美谷歌的 Chinchilla-70B 和 PaLM-540B。

Vicuna 模型

Vicuna 是一个由 UC 伯克利、CMU、斯坦福等机构的学者联手发布的最新开源大模型。基于 Meta 开源的 LLaMA 大模型,使用 ShareGPT 平台上的用户共享对话数据微调而来。包含 7B 和 13B 两个型号的开源预训练模型。

在这里插入图片描述

下载模型

# 下载 Vicuna 7B
# !git lfs clone http://git.aistudio.baidu.com/180581/vicuna-7b-v1.1.git# 下载 Vicuna 13B
!git lfs clone http://git.aistudio.baidu.com/180581/vicuna-13b-v1.1.git

开发环境

!pip install --pre --upgrade paddlenlp -f https://www.paddlepaddle.org.cn/whl/paddlenlp.html --user
!pip install paddlepaddle-gpu==0.0.0.post112 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html --user

代码

import os
import glob
import paddlefrom tqdm import tqdm
from paddlenlp.transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizerpattern = 'paddle-model-?????-of-?????.pdparams'# Vicuna 7B
# ckpt_dir = 'vicuna-7b-v1.1'
# config_dict =  {
#     "hidden_size": 4096,
#     "initializer_range": 0.02,
#     "intermediate_size": 11008,
#     "max_position_embeddings": 2048,
#     "model_type": "llama",
#     "num_attention_heads": 32,
#     "num_hidden_layers": 32,
#     "rms_norm_eps": 1e-06,
#     "vocab_size": 32000,
#     "bos_token_id": 1,
#     "eos_token_id": 2,
#     "pad_token_id": 0,
#     "use_cache": True,
#     "use_recompute": False,
#     "use_flash_attention": False,
# }# Vicuna 13B
ckpt_dir = 'vicuna-13b-v1.1'
config_dict =  {"hidden_size": 5120,"initializer_range": 0.02,"intermediate_size": 13824,"max_position_embeddings": 2048,"model_type": "llama","num_attention_heads": 40,"num_hidden_layers": 40,"rms_norm_eps": 1e-06,"vocab_size": 32000,"bos_token_id": 1,"eos_token_id": 2,"pad_token_id": 0,"use_cache": True,"use_recompute": False,"use_flash_attention": False,
}paddle.set_default_dtype('float16')tokenizer = LlamaTokenizer.from_pretrained(ckpt_dir)config = LlamaConfig(**config_dict)model = LlamaForCausalLM(config)
model.eval()for name, layer in model.named_sublayers():if 'rotary_emb' in name:layer.inv_freq = layer.inv_freq.cast(paddle.float32)paddle.device.cuda.empty_cache()for file_path in tqdm(glob.glob(os.path.join(ckpt_dir, pattern))):params = paddle.load(file_path)assert model.set_dict(params)[1] == [], 'Load error.'del paramspaddle.device.cuda.empty_cache()input_text = input('USER: ')
prompt = f'''USER: {input_text}\n\nASSISTANT: '''
with paddle.no_grad():with paddle.amp.auto_cast(False, level='O2', dtype='float16'):while True:if input_text == 'exit':breakinputs = tokenizer(prompt, return_tensors="pd", return_attention_mask=True,return_position_ids=True)outputs = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, position_ids=inputs.position_ids, max_length=2048-inputs.input_ids.shape[1], min_length=0, decode_strategy="sampling",temperature=0.8, top_k=40, top_p=0.95, repetition_penalty=1.1,bos_token_id=tokenizer.bos_token_id,eos_token_id=tokenizer.eos_token_id,pad_token_id=tokenizer.pad_token_id,use_cache=True, use_fast=True, use_fp16_decoding=True)response = tokenizer.decode(outputs[0][0], skip_special_tokens=True)print('ASSISTANT: ' + response)input_text = input('USER: ')prompt += f'''{response}\n\nUSER: {input_text}\n\nASSISTANT: '''del inputsdel outputsdel responsepaddle.device.cuda.empty_cache()
http://www.lryc.cn/news/153225.html

相关文章:

  • jackson常用操作
  • ios ipa包上传需要什么工具
  • 科目1基础知识快速入门精简
  • 安卓逆向 - 某东app加密参数还原
  • Visual Studio(2022)生成链接过程的.map映射文件以及.map映射文件的内容说明
  • A. Gift Carpet
  • 技术科普:汽车开放系统架构AUTOSAR
  • 说说HTTP 和 HTTPS 有什么区别?
  • Pygame中Trivia游戏解析6-5
  • Java8新特性2——方法引用
  • Mac“其他文件”存放着什么?“其他文件”的清理方法
  • 46、TCP的“三次握手”
  • libudev 和 libusb 常见API分析
  • [dasctf]misc04
  • Scala的函数式编程与高阶函数,匿名函数,偏函数,函数的闭包、柯里化,抽象控制,懒加载等
  • Axure RP 8.1.0.3400(原型设计工具)
  • 企业微信、飞书、钉钉机器人消息发送工具类
  • 手撕 视觉slam14讲 ch7 / pose_estimation_3d2d.cpp (1)
  • Mac安装Dart时,Homebrew报错 Error: Failure while executing
  • SSM整合~
  • Self-supervised 3D Human Pose Estimation from a Single Image
  • ubuntu下cups部分场景
  • 通过geoserver imageMosic发布多张tif数据
  • 输出图元(四)8-2 OpenGL画点函数、OpenGL画线函数
  • java八股文
  • 算法通关村——解析堆的应用
  • 爬虫源码---爬取小猫猫交易网站
  • Python的由来和基础语法(一)
  • 使用maven创建springboot项目
  • MySQL 基本操作1