
Using Multiple GPUs with SentenceTransformer to Speed Up Vectorization

Contents

  • Preface
  • Code


Preface

When we need to vectorize a large amount of data to store in a vector database, and the server has several GPUs at our disposal, we want to use all of them in parallel to speed up the vectorization.

Code

It only takes a few lines of code, so let's get right to it.

from sentence_transformers import SentenceTransformer

# Important, you need to shield your code with if __name__. Otherwise, CUDA runs into issues when spawning new processes.
if __name__ == '__main__':
    # Create a large list of 100k sentences
    sentences = ["This is sentence {}".format(i) for i in range(100000)]

    # Define the model
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Start the multi-process pool on all available CUDA devices
    pool = model.start_multi_process_pool()

    # Compute the embeddings using the multi-process pool
    emb = model.encode_multi_process(sentences, pool)
    print("Embeddings computed. Shape:", emb.shape)

    # Optional: Stop the processes in the pool
    model.stop_multi_process_pool(pool)

Note: you must include the line if __name__ == '__main__':, otherwise you will get the following error:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
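For context: the error comes from Python's multiprocessing start method. CUDA requires 'spawn' (or 'forkserver') rather than 'fork', and under 'spawn' every worker process re-imports the main module from the top, so any pool-starting code left at module level would run again in each child. A minimal standalone illustration of the pattern (a hypothetical script, not from the original post):

import multiprocessing as mp

def work(x):
    return x * x

if __name__ == '__main__':
    mp.set_start_method('spawn')   # the start method CUDA needs; 'fork' is unsafe with CUDA
    with mp.Pool(2) as pool:       # everything inside the guard runs only in the parent process
        print(pool.map(work, range(4)))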

In fact, the official documentation already provides this code; I merely copied and pasted it. Source: computing_embeddings_multi_gpu.py
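One small variation worth knowing (my addition, not part of the official snippet): if you only want to use some of the GPUs, start_multi_process_pool accepts an explicit device list. A minimal sketch, assuming you want to restrict the pool to the first two cards; check the parameter against your installed sentence-transformers version:

from sentence_transformers import SentenceTransformer

if __name__ == '__main__':
    sentences = ["This is sentence {}".format(i) for i in range(100000)]
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Start worker processes only on the listed devices instead of every visible GPU
    pool = model.start_multi_process_pool(target_devices=['cuda:0', 'cuda:1'])
    emb = model.encode_multi_process(sentences, pool)
    print("Embeddings computed. Shape:", emb.shape)

    model.stop_multi_process_pool(pool)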

The official docs also provide a streaming encode example, likewise running across multiple GPUs in parallel:

from sentence_transformers import SentenceTransformer, LoggingHandler
import logging
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

# Important, you need to shield your code with if __name__. Otherwise, CUDA runs into issues when spawning new processes.
if __name__ == '__main__':
    # Set params
    data_stream_size = 16384   # Size of the data that is loaded into memory at once
    chunk_size = 1024          # Size of the chunks that are sent to each process
    encode_batch_size = 128    # Batch size of the model

    # Load a large dataset in streaming mode. more info: https://huggingface.co/docs/datasets/stream
    dataset = load_dataset('yahoo_answers_topics', split='train', streaming=True)
    dataloader = DataLoader(dataset.with_format("torch"), batch_size=data_stream_size)

    # Define the model
    model = SentenceTransformer('all-MiniLM-L6-v2')

    # Start the multi-process pool on all available CUDA devices
    pool = model.start_multi_process_pool()

    for i, batch in enumerate(tqdm(dataloader)):
        # Compute the embeddings using the multi-process pool
        sentences = batch['best_answer']
        batch_emb = model.encode_multi_process(sentences, pool, chunk_size=chunk_size, batch_size=encode_batch_size)
        print("Embeddings computed for 1 batch. Shape:", batch_emb.shape)

    # Optional: Stop the processes in the pool
    model.stop_multi_process_pool(pool)

Official example: computing_embeddings_streaming.py
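Since the goal in the preface is to load the vectors into a vector database afterwards, here is a minimal sketch of persisting the embeddings to disk in chunks instead of only printing their shape. The chunk size and the file name embeddings.npy are my own illustration, not from the official examples:

from sentence_transformers import SentenceTransformer
import numpy as np

if __name__ == '__main__':
    sentences = ["This is sentence {}".format(i) for i in range(100000)]
    model = SentenceTransformer('all-MiniLM-L6-v2')
    pool = model.start_multi_process_pool()

    chunk = 20000   # how many sentences to send through the pool per round
    parts = []
    for start in range(0, len(sentences), chunk):
        parts.append(model.encode_multi_process(sentences[start:start + chunk], pool))

    model.stop_multi_process_pool(pool)

    embeddings = np.concatenate(parts, axis=0)    # shape: (100000, embedding_dim)
    np.save('embeddings.npy', embeddings)         # ready to be bulk-loaded into a vector database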

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A800-SXM...  On   | 00000000:23:00.0 Off |                    0 |
| N/A   58C    P0   297W / 400W |  75340MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A800-SXM...  On   | 00000000:29:00.0 Off |                    0 |
| N/A   71C    P0   352W / 400W |  80672MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A800-SXM...  On   | 00000000:52:00.0 Off |                    0 |
| N/A   68C    P0   398W / 400W |  75756MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A800-SXM...  On   | 00000000:57:00.0 Off |                    0 |
| N/A   58C    P0   341W / 400W |  75994MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A800-SXM...  On   | 00000000:8D:00.0 Off |                    0 |
| N/A   56C    P0   319W / 400W |  70084MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A800-SXM...  On   | 00000000:92:00.0 Off |                    0 |
| N/A   70C    P0   354W / 400W |  76314MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A800-SXM...  On   | 00000000:BF:00.0 Off |                    0 |
| N/A   73C    P0   360W / 400W |  75876MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A800-SXM...  On   | 00000000:C5:00.0 Off |                    0 |
| N/A   57C    P0   364W / 400W |  80404MiB / 81920MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

Blazingly fast: the nvidia-smi snapshot above, taken while the job was encoding, shows all eight cards running at 100% utilization.

