当前位置: 首页 > news >正文

基于 Tornado 为 BELLE 搭建本地 Web 服务

我的github

将 BELLE 封装成 Web 后端服务，采用 Tornado 框架。
import json
import logging
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn
import tornado.escape
import tornado.ioloop
import tornado.web
from transformers import AutoTokenizer
#import lightgbm as lgb

from gptq import *
from modelutils import *
from quant import *
# Device every model and input tensor is moved to.
DEV = torch.device('cuda:0')


def get_bloom(model):
    """Load an unquantized BLOOM causal language model.

    Args:
        model: Model name or local path passed to
            ``BloomForCausalLM.from_pretrained``.

    Returns:
        The loaded model, with a ``seqlen`` attribute of 2048 attached.
    """
    def skip(*args, **kwargs):
        pass

    # Skip random weight initialisation: the weights are overwritten by the
    # pretrained checkpoint anyway. NOTE: these patches replace the
    # torch.nn.init functions process-wide and are never restored.
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    from transformers import BloomForCausalLM

    model = BloomForCausalLM.from_pretrained(model, torch_dtype='auto')
    model.seqlen = 2048
    return model


def load_quant(model, checkpoint, wbits, groupsize):
    """Build a BLOOM skeleton and load a GPTQ-quantized checkpoint into it.

    Args:
        model: Model name or path from which ``BloomConfig`` is read.
        checkpoint: Path to the quantized weights. Files ending in
            ``.safetensors`` are read with safetensors; anything else with
            ``torch.load`` (tensors mapped to CUDA).
        wbits: Bit width forwarded to ``make_quant``.
        groupsize: Quantization group size forwarded to ``make_quant``.

    Returns:
        The model in eval mode, with ``seqlen`` set to 2048.
    """
    # Imported explicitly: ``transformers.modeling_utils`` is used below and
    # should not depend on a star import happening to export the name.
    import transformers
    from transformers import BloomConfig, BloomForCausalLM

    config = BloomConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Skip random weight initialisation (weights come from the checkpoint).
    # NOTE: these patches are process-wide and are never restored.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    transformers.modeling_utils._init_weights = False
    # Build the skeleton in fp16, and restore the float default even if
    # construction fails so later tensors are not silently created as half.
    torch.set_default_dtype(torch.half)
    try:
        model = BloomForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)
    model = model.eval()

    layers = find_layers(model)
    # lm_head is removed from the layer map so make_quant leaves it alone.
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant(model, layers, wbits, groupsize)

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        model.load_state_dict(
            torch.load(checkpoint, map_location=torch.device('cuda'))
        )
    model.seqlen = 2048
    print('Done.')
    return model


import argparse
from datautils import *

# Command-line configuration. It is parsed at import time because the HTTP
# handler reads `args`, `model` and `tokenizer` as module-level globals.
parser = argparse.ArgumentParser()
parser.add_argument(
    'model', type=str,
    help='bloom model to load'
)
parser.add_argument(
    '--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16],
    help='#bits to use for quantization; use 16 for evaluating base model.'
)
parser.add_argument(
    '--groupsize', type=int, default=-1,
    help='Groupsize to use for quantization; default uses full row.'
)
parser.add_argument(
    '--load', type=str, default='',
    help='Load quantized model.'
)
parser.add_argument(
    '--text', type=str,
    help='Prompt text (currently not read anywhere in this script).'
)
parser.add_argument(
    '--min_length', type=int, default=10,
    help='The minimum length of the sequence to be generated.'
)
parser.add_argument(
    '--max_length', type=int, default=1024,
    help='The maximum length of the sequence to be generated.'
)
parser.add_argument(
    '--top_p', type=float, default=0.95,
    help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'
)
parser.add_argument(
    '--temperature', type=float, default=0.8,
    help='The value used to modulate the next token probabilities.'
)
args = parser.parse_args()

# Defensive: --load is declared type=str, so this only matters if args.load
# is replaced with a pathlib.Path programmatically.
if not isinstance(args.load, str):
    args.load = args.load.as_posix()

if args.load:
    model = load_quant(args.model, args.load, args.wbits, args.groupsize)
else:
    model = get_bloom(args.model)
# Applied to both branches: load_quant builds its skeleton on the CPU, so the
# quantized model needs moving to the GPU too; eval() is a no-op if repeated.
model.eval()
model.to(DEV)

tokenizer = AutoTokenizer.from_pretrained(args.model)

# One-off demo prompt. With the generation snippet below disabled, nothing
# consumes `inputs` / `input_ids`.
print("Human:")
inputs = 'Human: ' + 'hello' + '\n\nAssistant:'
input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)

# Disabled one-shot demo (previously kept as a bare string literal):
#
# with torch.no_grad():
#     generated_ids = model.generate(
#         input_ids,
#         do_sample=True,
#         min_length=args.min_length,
#         max_length=args.max_length,
#         top_p=args.top_p,
#         temperature=args.temperature,
#     )
# print("Assistant:\n")
# # (original note) a bos_token is prepended to generated_ids; truncate the
# # input part so that only the Assistant reply is printed.
# print(tokenizer.decode([el.item() for el in generated_ids[0]])[len(inputs):])
# print("\n-------------------------------\n")

# Example launch (4-bit GPTQ checkpoint, group size 128):
# python bloom_inference.py BELLE_BLOOM_GPTQ_4BIT  --temperature 1.2  --wbits 4 --groupsize 128 --load  BELLE_BLOOM_GPTQ_4BIT/bloom7b-2m-4bit-128g.pt
class GateAPIHandler(tornado.web.RequestHandler):
    """Tornado handler that answers prompts with the loaded BELLE model.

    POST a form-encoded body with a ``status`` field holding the user's
    question. The response body is JSON: ``{"belle": "<generated answer>"}``.
    Reads the module-level ``model``, ``tokenizer``, ``args`` and ``DEV``.
    """

    # One worker thread: generations still run one at a time (as they did
    # when generate() was called inline), but off the IOLoop, so the server
    # keeps accepting connections while the model is busy.
    _generate_executor = ThreadPoolExecutor(max_workers=1)

    def set_default_headers(self):
        # Re-applied whenever the response is cleared (e.g. by send_error),
        # so error responses carry the CORS header too.
        self.set_header("Access-Control-Allow-Origin", "*")

    @staticmethod
    def _generate(question):
        """Run blocking sampling for *question* and return the answer text."""
        inputs = 'Human: ' + question + '\n\nAssistant:'
        input_ids = tokenizer.encode(inputs, return_tensors="pt").to(DEV)
        # no_grad is thread-local, so it must be entered in the worker thread.
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids,
                do_sample=True,
                min_length=args.min_length,
                max_length=args.max_length,
                top_p=args.top_p,
                temperature=args.temperature,
            )
        # generate() returns the prompt tokens followed by the new ones. Drop
        # the prompt by token count rather than slicing len(inputs)
        # characters off the decoded text, which goes wrong whenever
        # decode(encode(inputs)) does not reproduce `inputs` exactly.
        new_tokens = generated_ids[0, input_ids.shape[-1]:]
        return tokenizer.decode(new_tokens.tolist())

    async def post(self):
        """Generate an answer for the ``status`` form field.

        Raises:
            tornado.web.HTTPError: 400 if ``status`` is missing or is not
                valid UTF-8; 500 if generation fails.
        """
        # strip=False keeps the prompt exactly as sent, as the raw-bytes code
        # did; invalid UTF-8 is rejected by tornado with a 400.
        questions = self.get_body_arguments("status", strip=False)
        if not questions:
            # Must be raised: merely returning the HTTPError object sent an
            # empty 200 response.
            raise tornado.web.HTTPError(400)
        question = questions[0]
        logging.debug("question: %s", question)

        loop = tornado.ioloop.IOLoop.current()
        try:
            answer = await loop.run_in_executor(
                self._generate_executor, self._generate, question
            )
        except Exception as e:
            logging.exception("Error: {0}.".format(e))
            raise tornado.web.HTTPError(500) from e
        logging.info("Assistant: %s", answer)

        self.set_header("Content-Type", "application/json; charset=UTF-8")
        self.write(json.dumps({'belle': answer}))

    def get(self):
        # Only POST is supported on this endpoint.
        raise tornado.web.HTTPError(405)


import logging
import tornado.autoreload
import tornado.ioloop
import tornado.options
import tornado.web
import tornado.httpserver
#import   itempredict
import argparse
from tornado.httpserver import HTTPServer#trace()
if __name__ == "__main__":
    tornado.options.define("port", default=8081, type=int,
                           help="This is a port number")
    # argparse has already parsed sys.argv at import time. Tornado's own
    # parser re-reads sys.argv and raises on options it does not define
    # (e.g. --wbits) when they appear before the model path, so give it only
    # the program name; it still runs its parse callbacks (logging setup).
    # NOTE(review): the port therefore stays at its default of 8081; argparse
    # would reject a --port flag anyway, since it declares no such option.
    tornado.options.parse_command_line([sys.argv[0]])

    app = tornado.web.Application([
        (r"/", GateAPIHandler),
    ])
    apiport = tornado.options.options.port
    # Application.listen() creates and binds its own HTTPServer; a second,
    # never-bound HTTPServer is not needed.
    app.listen(apiport)
    logging.info("Start Gate API server on port {0}.".format(apiport))
    tornado.ioloop.IOLoop.current().start()
# Minimal client for the Gate API server: POST one question, print the JSON
# answer and the round-trip time.
import base64
import json
import time

import requests
# NOTE(review): read_wav_bytes is not used by this client; the import raises
# ImportError unless a `utils.ops` module is importable. Consider removing.
from utils.ops import read_wav_bytes

# Address of the machine running the server; adjust for your network.
URL = 'http://192.168.3.9:8081'

data = {'status': ' 如何理解黑格尔的 量变引起质变规律和否定之否定规律',}
t0 = time.time()
# Fail fast if the server is unreachable; no read timeout, because a long
# generation can legitimately take minutes.
r = requests.post(URL, data=data, timeout=(10, None))
t1 = time.time()
# Surface 4xx/5xx explicitly instead of failing later on json.loads of an
# error page.
r.raise_for_status()
r.encoding = 'utf-8'
result = json.loads(r.text)
print(result)
print('time:', t1 - t0, 's')

在这里插入图片描述

http://www.lryc.cn/news/195332.html

相关文章:

  • 信息系统漏洞与风险管理制度
  • Hadoop3教程(十七):MapReduce之ReduceJoin案例分析
  • BAT026:删除当前目录及子目录下的空文件夹
  • nodejs+vue网课学习平台
  • Can Language Models Make Fun? A Case Study in Chinese Comical Crosstalk
  • 阿里云云服务器实例使用教学
  • promisify 是 Node.js 标准库 util 模块中的一个函数
  • ArcGIS在VUE框架中的构建思想
  • 【Overload游戏引擎细节分析】视图投影矩阵计算与摄像机
  • 什么是云原生?零基础学云原生难吗?
  • Ubuntu18.04下载安装基于使用QT的pcl1.13+vtk8.2,以及卸载
  • 7 使用Docker容器管理的tomcat容器中的项目连接mysql数据库
  • 双节前把我的网站重构了一遍
  • 基于 nodejs+vue网上考勤系统
  • 以数智化指标管理,驱动光伏能源行业的市场推进
  • lv8 嵌入式开发-网络编程开发 18 广播与组播的实现
  • 前端面试题个人笔记(后面继续更新完善)
  • 软件设计之工厂方法模式
  • 【Linux】shell运行原理及权限
  • OA系统和ERP系统有什么区别?
  • c语言之strcat函数使用和实现
  • Halo-Theme-Hao文档:如何设置导航栏?
  • 【Java学习之道】Java网络编程API介绍
  • [论文笔记]SimCSE
  • 设置按键中断：按键1按下，LED亮，再按一次，灭；按键2按下，蜂鸣器响，再按一次，不响；按键3按下，风扇转，再按一次，风扇停
  • 深拷贝和浅拷贝的主要区别
  • Git Cherry Pick的使用
  • vue3后台管理框架之基础配置
  • Easysearch压缩模式深度比较:ZSTD+source_reuse的优势分析
  • 扩散模型的系统性学习(一):DDPM的学习