
Using the Hugging Face transformers pipeline

What is the Hugging Face transformers pipeline?

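In the Hugging Face transformers library, pipeline is a high-level API that makes it easy to use pretrained models for many NLP tasks. Examples include sentiment analysis, text generation, translation and question answering. With a pipeline, a complex NLP task takes only a few lines of code. The pipeline automatically loads a default model and tokenizer for the chosen task; if needed, you can specify a particular model and tokenizer instead.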

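When creating a pipeline, you can set more than the task type and model. Other parameters include the deep learning framework (framework, where "pt" means PyTorch and "tf" means TensorFlow), the device (CPU or GPU) and the batch size (batch_size). Under the hood, a pipeline initializes the tokenizer and the model and takes care of preprocessing the input data.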
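In the library, each Pipeline subclass splits a call into three methods: preprocess, _forward and postprocess (plus _sanitize_parameters for argument handling). The toy class below mimics only that data flow in plain Python. Its whitespace "tokenizer" and keyword-counting "model" are hypothetical stand-ins, not transformers code, so it runs without downloading anything.

```python
# Toy illustration of the preprocess -> _forward -> postprocess flow.
# NOT transformers code: the "tokenizer" and "model" are hypothetical stand-ins
# so the data flow can run without downloading anything.
from typing import Dict, List, Set


class ToySentimentPipeline:
    def __init__(self, vocab: Dict[str, int], positive_ids: Set[int], negative_ids: Set[int]):
        self.vocab = vocab                # hypothetical word -> id table (unknown words -> 0)
        self.positive_ids = positive_ids  # ids the toy "model" treats as positive
        self.negative_ids = negative_ids  # ids the toy "model" treats as negative

    def preprocess(self, text: str) -> List[int]:
        # Stand-in for the tokenizer: lowercase, split on whitespace, map words to ids.
        return [self.vocab.get(word, 0) for word in text.lower().split()]

    def _forward(self, input_ids: List[int]) -> List[float]:
        # Stand-in for the model: two raw scores, [negative, positive].
        neg = sum(1 for i in input_ids if i in self.negative_ids)
        pos = sum(1 for i in input_ids if i in self.positive_ids)
        return [float(neg), float(pos)]

    def postprocess(self, scores: List[float]) -> Dict[str, object]:
        # Convert raw scores into the {"label", "score"} shape pipelines return.
        labels = ["NEGATIVE", "POSITIVE"]
        best = max(range(len(scores)), key=lambda i: scores[i])
        total = sum(scores) or 1.0  # avoid dividing by zero when no keyword matched
        return {"label": labels[best], "score": scores[best] / total}

    def __call__(self, text: str) -> Dict[str, object]:
        return self.postprocess(self._forward(self.preprocess(text)))
```

In the real library, these three methods also run in sequence when you call the pipeline object, with _forward executing the actual model.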

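The main features of pipelines are summarized below:

  1. Task-specific: Pipelines offer dedicated interfaces for many NLP tasks. These include text classification, named entity recognition, question answering, text generation, translation, summarization and sentiment analysis.
  2. Automatic model loading: You don't need to deal with model details; the pipeline automatically loads a pretrained model and tokenizer suited to the task.
  3. Ease of use: The API is concise; a few lines of code are enough to load a model and run a task.
  4. Automatic tokenization: Pipelines tokenize text internally, converting it into the format the model expects.
  5. Batching: Pipelines can process multiple inputs at once, for example via the batch_size argument.
  6. Padding and truncation: When batching, pipelines pad inputs to a common length. They also accept tokenizer options such as truncation for inputs that are too long.
  7. Custom models and tokenizers: You can specify your own model and tokenizer to meet specific needs.
  8. Fine-tuned models: A pipeline does not fine-tune anything itself. You can fine-tune a model on your own dataset beforehand and then load the fine-tuned checkpoint into a pipeline.
  9. Multilingual support: Many pipelines work with multiple languages, depending on the model, which makes cross-lingual NLP tasks possible.
  10. Extensibility: Pipelines can serve as building blocks for more complex NLP workflows.
  11. Performance: Pipelines can run on a GPU (device) and stream large datasets in batches, which improves throughput in common use cases.
  12. Error reporting: Pipelines raise descriptive errors, for example when the task name is unknown or a model cannot be loaded.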
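Batching and padding (items 5 and 6) are linked. A batch must form a rectangular tensor, so shorter sequences are padded, and an attention mask tells the model which positions are padding. The function below is a simplified plain-Python illustration of that idea, assuming pad id 0 and right-side padding. In practice the tokenizer does this for you, and this is not its actual implementation.

```python
# Simplified illustration of truncation + padding for a batch of token-id lists.
# Assumptions: pad id 0 and right-side padding; real tokenizers handle this
# internally (e.g. padding=True, truncation=True, max_length=...).
from typing import Dict, List


def pad_and_truncate(batch: List[List[int]], max_length: int, pad_id: int = 0) -> Dict[str, List[List[int]]]:
    # Truncate every sequence to at most max_length tokens.
    truncated = [seq[:max_length] for seq in batch]
    # Pad to the longest remaining sequence so the batch forms a rectangle.
    target = max((len(seq) for seq in truncated), default=0)
    input_ids = [seq + [pad_id] * (target - len(seq)) for seq in truncated]
    # attention_mask marks real tokens with 1 and padding with 0.
    attention_mask = [[1] * len(seq) + [0] * (target - len(seq)) for seq in truncated]
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```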

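With pipelines, researchers and developers can quickly prototype and deploy NLP applications without a deep understanding of model internals. In short, pipelines are a powerful and flexible tool in the Hugging Face Transformers library for simplifying NLP workflows.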

Supported tasks

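Pipelines cover audio, computer vision, natural language processing and multimodal tasks. The excerpt below is from transformers/pipelines/__init__.py in a transformers release. TASK_ALIASES maps alias names to canonical task names. SUPPORTED_TASKS lists, for each task, the pipeline class, the model classes and the default model with its pinned revision.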

TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
    "text-to-speech": "text-to-audio",
}
SUPPORTED_TASKS = {
    "audio-classification": {
        "impl": AudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        "type": "audio",
    },
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
        "type": "multimodal",
    },
    "text-to-audio": {
        "impl": TextToAudioPipeline,
        "tf": (),
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
        "type": "text",
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased", "935ac13"),
                "tf": ("distilbert/distilbert-base-cased", "935ac13"),
            }
        },
        "type": "multimodal",
    },
    "text-classification": {
        "impl": TextClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
            },
        },
        "type": "text",
    },
    "token-classification": {
        "impl": TokenClassificationPipeline,
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
            },
        },
        "type": "text",
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
            },
        },
        "type": "text",
    },
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        "default": {
            "model": {
                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
            },
        },
        "type": "text",
    },
    "visual-question-answering": {
        "impl": VisualQuestionAnsweringPipeline,
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
        },
        "type": "multimodal",
    },
    "document-question-answering": {
        "impl": DocumentQuestionAnsweringPipeline,
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("impira/layoutlm-document-qa", "52e01b3")},
        },
        "type": "multimodal",
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilroberta-base", "ec58a5b"),
                "tf": ("distilbert/distilroberta-base", "ec58a5b"),
            }
        },
        "type": "text",
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "d769bba")}
        },
        "type": "text",
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        },
        "type": "text",
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        "type": "text",
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("openai-community/gpt2", "6c0e608"), "tf": ("openai-community/gpt2", "6c0e608")}},
        "type": "text",
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
            "config": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
        },
        "type": "text",
    },
    "zero-shot-image-classification": {
        "impl": ZeroShotImageClassificationPipeline,
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),
            }
        },
        "type": "multimodal",
    },
    "zero-shot-audio-classification": {
        "impl": ZeroShotAudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("laion/clap-htsat-fused", "973b6e5"),
            }
        },
        "type": "multimodal",
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
        },
        "type": "text",
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "5dca96d"),
                "tf": ("google/vit-base-patch16-224", "5dca96d"),
            }
        },
        "type": "image",
    },
    "image-feature-extraction": {
        "impl": ImageFeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "3f49326"),
                "tf": ("google/vit-base-patch16-224", "3f49326"),
            }
        },
        "type": "image",
    },
    "image-segmentation": {
        "impl": ImageSegmentationPipeline,
        "tf": (),
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
        "type": "multimodal",
    },
    "image-to-text": {
        "impl": ImageToTextPipeline,
        "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
        "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
                "tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
            }
        },
        "type": "multimodal",
    },
    "object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
        "type": "multimodal",
    },
    "zero-shot-object-detection": {
        "impl": ZeroShotObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
        "type": "multimodal",
    },
    "depth-estimation": {
        "impl": DepthEstimationPipeline,
        "tf": (),
        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
        "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
        "type": "image",
    },
    "video-classification": {
        "impl": VideoClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
        "type": "video",
    },
    "mask-generation": {
        "impl": MaskGenerationPipeline,
        "tf": (),
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
        "type": "multimodal",
    },
    "image-to-image": {
        "impl": ImageToImagePipeline,
        "tf": (),
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},
        "type": "image",
    },
}
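These two tables drive task lookup. An alias such as "sentiment-analysis" is normalized to "text-classification", and the task's "default" entry supplies the model name and revision used when none is given. The helper below is a hypothetical, simplified illustration of that lookup, which in the library happens inside the pipeline() factory. It copies only four PyTorch defaults from the table and returns keyword arguments, so the default can be pinned explicitly.

```python
# Hypothetical helper illustrating alias resolution + default-model lookup.
# Only a few entries are copied from the tables above; this is not the
# transformers implementation.
from typing import Dict, Tuple

TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
    "text-to-speech": "text-to-audio",
}

# Subset of SUPPORTED_TASKS: task -> (default PyTorch model, pinned revision).
DEFAULT_PT_MODELS: Dict[str, Tuple[str, str]] = {
    "text-classification": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
    "token-classification": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
    "summarization": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"),
    "text-generation": ("openai-community/gpt2", "6c0e608"),
}


def pinned_pipeline_kwargs(task: str) -> Dict[str, str]:
    """Return task/model/revision kwargs that pin the task's default model."""
    normalized = TASK_ALIASES.get(task, task)  # aliases map to canonical task names
    if normalized not in DEFAULT_PT_MODELS:
        raise KeyError(f"Unknown task {task!r}; known: {sorted(DEFAULT_PT_MODELS)}")
    model, revision = DEFAULT_PT_MODELS[normalized]
    return {"task": normalized, "model": model, "revision": revision}
```

Because pipeline() accepts task, model and revision keyword arguments, the result could be used as pipeline(**pinned_pipeline_kwargs("sentiment-analysis")). This pins the same model and revision the default would pick, but explicitly.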

Usage examples

Basic example

import json

from transformers import pipeline
from transformers.pipelines import get_supported_tasks

# No model is given, so the task's default model is downloaded
nlp = pipeline("sentiment-analysis")

# Single call: the pipeline returns a list with one dict per input
result = nlp("I hate you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: NEGATIVE, score: 0.9991
result = nlp("I love you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: POSITIVE, score: 0.9999

# Multiple inputs: pass a list of texts
result = nlp(["This restaurant is awesome", "This restaurant is awful"])
print(json.dumps(result))
# List all supported task names, including aliases
print(json.dumps(get_supported_tasks()))

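The execution output is shown below:

  • Because no model was specified, the default model for the task is downloaded from the Hugging Face Hub. For sentiment-analysis, the default is distilbert/distilbert-base-uncased-finetuned-sst-2-english at revision af0f99b. It is cached locally in a folder named models--distilbert--distilbert-base-uncased-finetuned-sst-2-english.
  • On Windows, downloaded models are cached by default under C:\Users\<username>\.cache\huggingface\hub\.
  • Relying on the default model without specifying a model name and revision is not recommended in production.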
python.exe Classification.py
No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:157: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\wang\.cache\huggingface\hub\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
label: NEGATIVE, score: 0.9991
label: POSITIVE, score: 0.9999
[{"label": "POSITIVE", "score": 0.9998743534088135}, {"label": "NEGATIVE", "score": 0.9996669292449951}]
["audio-classification", "automatic-speech-recognition", "conversational", "depth-estimation", "document-question-answering", "feature-extraction", "fill-mask", "image-classification", "image-feature-extraction", "image-segmentation", "image-to-image", "image-to-text", "mask-generation", "ner", "object-detection", "question-answering", "sentiment-analysis", "summarization", "table-question-answering", "text-classification", "text-generation", "text-to-audio", "text-to-speech", "text2text-generation", "token-classification", "translation", "video-classification", "visual-question-answering", "vqa", "zero-shot-audio-classification", "zero-shot-classification", "zero-shot-image-classification", "zero-shot-object-detection"]
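The label/score pairs in this log come from the model's raw logits. For a single-label classifier like the default SST-2 DistilBERT model, the text-classification pipeline applies a softmax and reports the highest-scoring class, mapped through the model config's id2label. The sketch below reproduces that last step in plain Python. The logits are made-up inputs, while the id2label mapping (0 is NEGATIVE, 1 is POSITIVE) matches that model's config.

```python
# Sketch of turning logits into a {"label", "score"} result:
# softmax over the logits, pick the best class, map its id through id2label.
# The logits used below are invented; ID2LABEL mirrors the SST-2 DistilBERT config.
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_label_score(logits: List[float], id2label: Dict[int, str]) -> Dict[str, object]:
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return {"label": id2label[best], "score": probs[best]}


ID2LABEL = {0: "NEGATIVE", 1: "POSITIVE"}
```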

Pipeline batching

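import datasets
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

# Unlabeled "unsupervised" split of the IMDb reviews dataset
dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")
pipe = pipeline(task="sentiment-analysis")
# KeyDataset feeds only the "text" field; reviews are sent to the model in batches of 8
# and truncated if they exceed the model's maximum input length
for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):
    print(out)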
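Conceptually, batch_size=8 means the inputs are grouped into chunks of eight before each forward pass. The loop above still receives one result per review. Below is a minimal generic sketch of that grouping step. It is a hypothetical helper, not the pipeline's internal code, which additionally pads each batch.

```python
# Minimal sketch of grouping a stream of inputs into fixed-size batches.
# Illustrates what batch_size means; not transformers' internal batching code.
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def make_batches(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(items)
    while True:
        chunk = list(islice(it, batch_size))  # take up to batch_size items
        if not chunk:
            return  # input exhausted
        yield chunk
```

The last batch may be smaller than batch_size when the number of inputs is not a multiple of it.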

Custom dataset

from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import pipeline

# device=0 runs the model on the first GPU (device=-1 runs it on CPU)
pipe = pipeline("text-classification", device=0)


class MyDataset(Dataset):
    # Dummy dataset: 5000 copies of the same sentence
    def __len__(self):
        return 5000

    def __getitem__(self, i):
        return "This is a test"


dataset = MyDataset()

# Stream the dataset through the pipeline with different batch sizes
for batch_size in [1, 8, 64, 256]:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    for out in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass
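The loop above only shows progress bars. To compare batch sizes with numbers, you could time each run. The harness below is a hypothetical sketch: run is any callable taking (inputs, batch_size). With the pipeline above it might be lambda xs, bs: list(pipe(xs, batch_size=bs)). Measured timings depend heavily on hardware, model and input lengths.

```python
# Hypothetical timing harness for comparing batch sizes.
# `run` must consume all outputs for the given inputs and batch size.
import time
from typing import Callable, Dict, Iterable, List


def time_batch_sizes(
    run: Callable[[List[str], int], object],
    inputs: List[str],
    batch_sizes: Iterable[int],
) -> Dict[int, float]:
    timings: Dict[int, float] = {}
    for bs in batch_sizes:
        start = time.perf_counter()
        run(inputs, bs)
        timings[bs] = time.perf_counter() - start  # elapsed seconds for this batch size
    return timings
```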

Text summarization

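from transformers import pipeline

# Use the default BART-based model (sshleifer/distilbart-cnn-12-6) with PyTorch;
# min_length/max_length bound the length of the generated summary in tokens
summarizer = pipeline("summarization")
print(summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20))

# Use T5 with TensorFlow (requires TensorFlow to be installed)
summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
print(summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20))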
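Summarization models accept only a limited number of input tokens. A common workaround for long documents is to split them into chunks and summarize each one. The helper below is a hypothetical sketch of such a split. It counts whitespace-separated words rather than tokens, so it only approximates the model's real limit.

```python
# Hypothetical pre-processing helper: split a long text into overlapping
# word-based chunks. Word counts only approximate token counts.
from typing import List


def chunk_words(text: str, max_words: int, overlap: int = 0) -> List[str]:
    if max_words < 1 or not 0 <= overlap < max_words:
        raise ValueError("need max_words >= 1 and 0 <= overlap < max_words")
    words = text.split()
    step = max_words - overlap  # how far each chunk's start advances
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + max_words]))
        if start + max_words >= len(words):
            break  # this chunk already reaches the end of the text
    return chunks
```

Pipelines accept a list of inputs, so the chunks could be summarized together. For example, summarizer(chunk_words(long_text, 300, 50), min_length=5, max_length=20) returns one summary per chunk; long_text and the sizes 300 and 50 are arbitrary placeholders.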

Learning resources

  • https://huggingface.co/docs/transformers/main_classes/pipelines
  • https://huggingface.co/docs/transformers/task_summary#text-classification
