Using the Hugging Face transformers pipeline
What is the Hugging Face transformers pipeline?
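In Hugging Face's transformers library, pipeline is a high-level API that offers a simple way to use pretrained models for a variety of NLP tasks, such as sentiment analysis, text generation, translation, and question answering. With pipeline, you can implement complex NLP tasks in a few lines of code. pipeline automatically loads the default model and tokenizer for the specified task; if needed, you can specify a particular model and tokenizer instead.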
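When creating a pipeline, besides the task type and model, you can also set other parameters, such as the deep learning framework ("pt" for PyTorch, "tf" for TensorFlow), the device (CPU or GPU), and the batch size. Under the hood, pipeline initializes the tokenizer and model, preprocesses the input, runs the model, and postprocesses the raw outputs into a readable result.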
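The flow just described (tokenize, run the model, turn raw outputs into a readable result) can be illustrated with a toy, dependency-free sketch. Real transformers pipelines organize this work in the `preprocess`, `_forward` and `postprocess` methods of the `Pipeline` class; here `ToySentimentPipeline`, its whitespace "tokenizer" and its keyword-counting "model" are hypothetical stand-ins, not library code. The sketch also mirrors the text-classification pipeline's habit of returning a list even for a single string, which is why the sentiment-analysis example indexes its result with `[0]`.

```python
from typing import Dict, List, Union

Result = Dict[str, Union[str, float]]

class ToySentimentPipeline:
    """Hypothetical toy pipeline: mirrors the preprocess -> _forward -> postprocess flow only."""

    POSITIVE = {"love", "awesome", "great", "good"}
    NEGATIVE = {"hate", "awful", "bad", "terrible"}

    def preprocess(self, text: str) -> List[str]:
        # Toy "tokenizer": lowercase and split on whitespace
        return text.lower().split()

    def _forward(self, tokens: List[str]) -> float:
        # Toy "model": share of sentiment words that are positive (0.5 if there are none)
        pos = sum(t in self.POSITIVE for t in tokens)
        neg = sum(t in self.NEGATIVE for t in tokens)
        return 0.5 if pos + neg == 0 else pos / (pos + neg)

    def postprocess(self, p_positive: float) -> Result:
        # Turn the raw model output into a readable {"label", "score"} dict
        if p_positive >= 0.5:
            return {"label": "POSITIVE", "score": p_positive}
        return {"label": "NEGATIVE", "score": 1.0 - p_positive}

    def __call__(self, inputs: Union[str, List[str]]) -> List[Result]:
        # Accept a single string or a list of strings; always return a list of results
        texts = [inputs] if isinstance(inputs, str) else list(inputs)
        return [self.postprocess(self._forward(self.preprocess(t))) for t in texts]


pipe = ToySentimentPipeline()
print(pipe("I hate you"))  # [{'label': 'NEGATIVE', 'score': 1.0}]
print(pipe(["This restaurant is awesome", "This restaurant is awful"]))
```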
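The main features of pipelines are summarized below:
- Task-specific: pipelines provide dedicated interfaces for many NLP tasks, such as text classification, named entity recognition, question answering, text generation, translation, summarization, and sentiment analysis.
- Automatic model loading: users do not need to deal with model details; pipelines automatically load a pretrained model and tokenizer suited to the task.
- Ease of use: pipelines offer a concise API; a few lines of code are enough to load a model and run a task.
- Automatic tokenization: pipelines tokenize text internally, converting it into a format the model understands.
- Batching: pipelines support batching, so multiple texts can be processed at once.
- Dynamic adjustment: pipelines adjust the model inputs as the data requires, for example through padding and truncation.
- Custom models and tokenizers: users can specify their own model and tokenizer to meet specific needs.
- Fine-tuned models: before running a task with a pipeline, users can fine-tune the model on their own dataset and then use the fine-tuned model in the pipeline.
- Multilingual support: many pipelines support multiple languages, which makes cross-lingual NLP tasks possible.
- Extensibility: pipelines can serve as building blocks for more complex NLP workflows.
- Performance: pipelines are optimized for common use cases to process NLP tasks efficiently.
- Error handling: pipelines provide error handling for problems that may occur while loading a model or processing text.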
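With pipelines, researchers and developers can quickly prototype and deploy NLP applications without a deep understanding of the models' inner workings. In short, pipelines are a powerful and flexible tool in the Hugging Face Transformers library for simplifying NLP task processing.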
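Batching and the padding/truncation feature go together: all sequences in a batch must share one length, so shorter ones are padded and over-long ones are cut. Below is a minimal pure-Python sketch, assuming the token IDs already exist, pad ID 0, right-side padding and a caller-chosen `max_length`. `pad_and_truncate` is a hypothetical helper; in transformers this work is done by the tokenizer (e.g. `padding=True`, `truncation=True`), which, unlike this sketch, keeps special tokens intact when truncating.

```python
from typing import Dict, List

def pad_and_truncate(batch: List[List[int]], max_length: int, pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: truncate each sequence to max_length, then right-pad to the longest one."""
    truncated = [ids[:max_length] for ids in batch]
    target = max(len(ids) for ids in truncated)
    input_ids, attention_mask = [], []
    for ids in truncated:
        pad = target - len(ids)
        input_ids.append(ids + [pad_id] * pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        attention_mask.append([1] * len(ids) + [0] * pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

enc = pad_and_truncate([[5, 6, 7], [5, 8, 9, 10, 11, 12, 7]], max_length=5)
print(enc["input_ids"])       # [[5, 6, 7, 0, 0], [5, 8, 9, 10, 11]]
print(enc["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```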
Supported task categories
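Pipelines can be used for audio, computer vision, natural language processing, and multimodal tasks. The task aliases and the registry of supported tasks, including each task's default model and revision, are defined in the transformers source (transformers/pipelines/__init__.py) of the version used here: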
TASK_ALIASES = {
    "sentiment-analysis": "text-classification",
    "ner": "token-classification",
    "vqa": "visual-question-answering",
    "text-to-speech": "text-to-audio",
}
SUPPORTED_TASKS = {
    "audio-classification": {
        "impl": AudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForAudioClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("superb/wav2vec2-base-superb-ks", "372e048")}},
        "type": "audio",
    },
    "automatic-speech-recognition": {
        "impl": AutomaticSpeechRecognitionPipeline,
        "tf": (),
        "pt": (AutoModelForCTC, AutoModelForSpeechSeq2Seq) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/wav2vec2-base-960h", "55bb623")}},
        "type": "multimodal",
    },
    "text-to-audio": {
        "impl": TextToAudioPipeline,
        "tf": (),
        "pt": (AutoModelForTextToWaveform, AutoModelForTextToSpectrogram) if is_torch_available() else (),
        "default": {"model": {"pt": ("suno/bark-small", "645cfba")}},
        "type": "text",
    },
    "feature-extraction": {
        "impl": FeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased", "935ac13"),
                "tf": ("distilbert/distilbert-base-cased", "935ac13"),
            }
        },
        "type": "multimodal",
    },
    "text-classification": {
        "impl": TextClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
                "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
            },
        },
        "type": "text",
    },
    "token-classification": {
        "impl": TokenClassificationPipeline,
        "tf": (TFAutoModelForTokenClassification,) if is_tf_available() else (),
        "pt": (AutoModelForTokenClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
                "tf": ("dbmdz/bert-large-cased-finetuned-conll03-english", "f2482bf"),
            },
        },
        "type": "text",
    },
    "question-answering": {
        "impl": QuestionAnsweringPipeline,
        "tf": (TFAutoModelForQuestionAnswering,) if is_tf_available() else (),
        "pt": (AutoModelForQuestionAnswering,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
                "tf": ("distilbert/distilbert-base-cased-distilled-squad", "626af31"),
            },
        },
        "type": "text",
    },
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": (AutoModelForTableQuestionAnswering,) if is_torch_available() else (),
        "tf": (TFAutoModelForTableQuestionAnswering,) if is_tf_available() else (),
        "default": {
            "model": {
                "pt": ("google/tapas-base-finetuned-wtq", "69ceee2"),
                "tf": ("google/tapas-base-finetuned-wtq", "69ceee2"),
            },
        },
        "type": "text",
    },
    "visual-question-answering": {
        "impl": VisualQuestionAnsweringPipeline,
        "pt": (AutoModelForVisualQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("dandelin/vilt-b32-finetuned-vqa", "4355f59")},
        },
        "type": "multimodal",
    },
    "document-question-answering": {
        "impl": DocumentQuestionAnsweringPipeline,
        "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (),
        "tf": (),
        "default": {
            "model": {"pt": ("impira/layoutlm-document-qa", "52e01b3")},
        },
        "type": "multimodal",
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (),
        "pt": (AutoModelForMaskedLM,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("distilbert/distilroberta-base", "ec58a5b"),
                "tf": ("distilbert/distilroberta-base", "ec58a5b"),
            }
        },
        "type": "text",
    },
    "summarization": {
        "impl": SummarizationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"), "tf": ("google-t5/t5-small", "d769bba")}
        },
        "type": "text",
    },
    # This task is a special case as it's parametrized by SRC, TGT languages.
    "translation": {
        "impl": TranslationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {
            ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "de"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
            ("en", "ro"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        },
        "type": "text",
    },
    "text2text-generation": {
        "impl": Text2TextGenerationPipeline,
        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}},
        "type": "text",
    },
    "text-generation": {
        "impl": TextGenerationPipeline,
        "tf": (TFAutoModelForCausalLM,) if is_tf_available() else (),
        "pt": (AutoModelForCausalLM,) if is_torch_available() else (),
        "default": {"model": {"pt": ("openai-community/gpt2", "6c0e608"), "tf": ("openai-community/gpt2", "6c0e608")}},
        "type": "text",
    },
    "zero-shot-classification": {
        "impl": ZeroShotClassificationPipeline,
        "tf": (TFAutoModelForSequenceClassification,) if is_tf_available() else (),
        "pt": (AutoModelForSequenceClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
            "config": {
                "pt": ("facebook/bart-large-mnli", "c626438"),
                "tf": ("FacebookAI/roberta-large-mnli", "130fb28"),
            },
        },
        "type": "text",
    },
    "zero-shot-image-classification": {
        "impl": ZeroShotImageClassificationPipeline,
        "tf": (TFAutoModelForZeroShotImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForZeroShotImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("openai/clip-vit-base-patch32", "f4881ba"),
                "tf": ("openai/clip-vit-base-patch32", "f4881ba"),
            }
        },
        "type": "multimodal",
    },
    "zero-shot-audio-classification": {
        "impl": ZeroShotAudioClassificationPipeline,
        "tf": (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("laion/clap-htsat-fused", "973b6e5"),
            }
        },
        "type": "multimodal",
    },
    "conversational": {
        "impl": ConversationalPipeline,
        "tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
        "pt": (AutoModelForSeq2SeqLM, AutoModelForCausalLM) if is_torch_available() else (),
        "default": {
            "model": {"pt": ("microsoft/DialoGPT-medium", "8bada3b"), "tf": ("microsoft/DialoGPT-medium", "8bada3b")}
        },
        "type": "text",
    },
    "image-classification": {
        "impl": ImageClassificationPipeline,
        "tf": (TFAutoModelForImageClassification,) if is_tf_available() else (),
        "pt": (AutoModelForImageClassification,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "5dca96d"),
                "tf": ("google/vit-base-patch16-224", "5dca96d"),
            }
        },
        "type": "image",
    },
    "image-feature-extraction": {
        "impl": ImageFeatureExtractionPipeline,
        "tf": (TFAutoModel,) if is_tf_available() else (),
        "pt": (AutoModel,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("google/vit-base-patch16-224", "3f49326"),
                "tf": ("google/vit-base-patch16-224", "3f49326"),
            }
        },
        "type": "image",
    },
    "image-segmentation": {
        "impl": ImageSegmentationPipeline,
        "tf": (),
        "pt": (AutoModelForImageSegmentation, AutoModelForSemanticSegmentation) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50-panoptic", "fc15262")}},
        "type": "multimodal",
    },
    "image-to-text": {
        "impl": ImageToTextPipeline,
        "tf": (TFAutoModelForVision2Seq,) if is_tf_available() else (),
        "pt": (AutoModelForVision2Seq,) if is_torch_available() else (),
        "default": {
            "model": {
                "pt": ("ydshieh/vit-gpt2-coco-en", "65636df"),
                "tf": ("ydshieh/vit-gpt2-coco-en", "65636df"),
            }
        },
        "type": "multimodal",
    },
    "object-detection": {
        "impl": ObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/detr-resnet-50", "2729413")}},
        "type": "multimodal",
    },
    "zero-shot-object-detection": {
        "impl": ZeroShotObjectDetectionPipeline,
        "tf": (),
        "pt": (AutoModelForZeroShotObjectDetection,) if is_torch_available() else (),
        "default": {"model": {"pt": ("google/owlvit-base-patch32", "17740e1")}},
        "type": "multimodal",
    },
    "depth-estimation": {
        "impl": DepthEstimationPipeline,
        "tf": (),
        "pt": (AutoModelForDepthEstimation,) if is_torch_available() else (),
        "default": {"model": {"pt": ("Intel/dpt-large", "e93beec")}},
        "type": "image",
    },
    "video-classification": {
        "impl": VideoClassificationPipeline,
        "tf": (),
        "pt": (AutoModelForVideoClassification,) if is_torch_available() else (),
        "default": {"model": {"pt": ("MCG-NJU/videomae-base-finetuned-kinetics", "4800870")}},
        "type": "video",
    },
    "mask-generation": {
        "impl": MaskGenerationPipeline,
        "tf": (),
        "pt": (AutoModelForMaskGeneration,) if is_torch_available() else (),
        "default": {"model": {"pt": ("facebook/sam-vit-huge", "997b15")}},
        "type": "multimodal",
    },
    "image-to-image": {
        "impl": ImageToImagePipeline,
        "tf": (),
        "pt": (AutoModelForImageToImage,) if is_torch_available() else (),
        "default": {"model": {"pt": ("caidas/swin2SR-classical-sr-x2-64", "4aaedcb")}},
        "type": "image",
    },
}
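When pipeline() is called without a model, these two tables drive the choice: an alias such as "sentiment-analysis" is first mapped through TASK_ALIASES, then the default (model, revision) pair for the selected framework is read from the task's "default" entry. The sketch below re-implements that lookup on a few entries copied from the tables above, including the translation special case keyed by (source, target) language. `resolve_default` is a hypothetical helper, not a transformers function; the real code performs more checks (for example, which frameworks are installed).

```python
from typing import Tuple

# Trimmed-down copies of entries from the tables above (values copied verbatim)
TASK_ALIASES = {"sentiment-analysis": "text-classification"}
DEFAULTS = {
    "text-classification": {"model": {
        "pt": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b"),
        "tf": ("distilbert/distilbert-base-uncased-finetuned-sst-2-english", "af0f99b")}},
    "summarization": {"model": {
        "pt": ("sshleifer/distilbart-cnn-12-6", "a4f8f3e"),
        "tf": ("google-t5/t5-small", "d769bba")}},
    "translation": {
        ("en", "fr"): {"model": {"pt": ("google-t5/t5-base", "686f1db"), "tf": ("google-t5/t5-base", "686f1db")}}},
}

def resolve_default(task: str, framework: str = "pt") -> Tuple[str, str]:
    """Hypothetical helper: map an alias to its task, then return the default (model_id, revision)."""
    task = TASK_ALIASES.get(task, task)
    if task.startswith("translation_"):
        # Task strings like "translation_en_to_fr" select a per-language-pair default
        src, tgt = task[len("translation_"):].split("_to_")
        entry = DEFAULTS["translation"][(src, tgt)]
    else:
        entry = DEFAULTS[task]
    return entry["model"][framework]

print(resolve_default("sentiment-analysis"))
# ('distilbert/distilbert-base-uncased-finetuned-sst-2-english', 'af0f99b')
print(resolve_default("summarization", framework="tf"))
print(resolve_default("translation_en_to_fr"))
```

Note how the same task can have different defaults per framework, as with summarization above.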
Usage examples
Simple example
from transformers import pipeline
from transformers.pipelines import get_supported_tasks
import json

nlp = pipeline("sentiment-analysis")

# Single input: the pipeline returns a list containing one dict, hence [0]
result = nlp("I hate you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: NEGATIVE, score: 0.9991
result = nlp("I love you")[0]
print(f"label: {result['label']}, score: {round(result['score'], 4)}")
# label: POSITIVE, score: 0.9999

# Multiple inputs in one call: one dict per input
result = nlp(["This restaurant is awesome", "This restaurant is awful"])
print(json.dumps(result))
# List every task name (including aliases) accepted by pipeline()
print(json.dumps(get_supported_tasks()))
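Running the script produces the log below. Points to note:
- Because no model was specified, pipeline downloads the task's default model from the Hugging Face Hub. For sentiment-analysis the default is distilbert/distilbert-base-uncased-finetuned-sst-2-english at revision af0f99b, cached in a folder named models--distilbert--distilbert-base-uncased-finetuned-sst-2-english.
- On Windows, downloaded models are stored by default in C:\Users\<username>\.cache\huggingface\hub\.
- Relying on the default model and revision in production is not recommended; specify both explicitly.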
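The cache layout follows a simple naming scheme: each model repo gets a folder named models--<org>--<name> under the hub cache directory (compare the path in the symlink warning below), and the directory can be relocated with environment variables. A simplified sketch, assuming only `HF_HUB_CACHE` and `HF_HOME` are consulted; huggingface_hub honors a few more settings, so treat this as an approximation. `hub_cache_folder` is a hypothetical name, not a huggingface_hub function.

```python
import os
from pathlib import Path
from typing import Mapping, Optional

def hub_cache_folder(repo_id: str, env: Optional[Mapping[str, str]] = None) -> Path:
    """Hypothetical helper: approximate where a model repo is cached (simplified rules)."""
    env = os.environ if env is None else env
    if "HF_HUB_CACHE" in env:
        hub = Path(env["HF_HUB_CACHE"])
    elif "HF_HOME" in env:
        hub = Path(env["HF_HOME"]) / "hub"
    else:
        # Default: ~/.cache/huggingface/hub (C:\Users\<username>\.cache\... on Windows)
        hub = Path.home() / ".cache" / "huggingface" / "hub"
    # "distilbert/distilbert-base-..." -> "models--distilbert--distilbert-base-..."
    return hub / ("models--" + repo_id.replace("/", "--"))

print(hub_cache_folder("distilbert/distilbert-base-uncased-finetuned-sst-2-english"))
```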
python.exe Classification.py
No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).
Using a pipeline without specifying a model name and revision in production is not recommended.
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
D:\soft\anaconda3\envs\llm-demo\lib\site-packages\huggingface_hub\file_download.py:157: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\wang\.cache\huggingface\hub\models--distilbert--distilbert-base-uncased-finetuned-sst-2-english. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
label: NEGATIVE, score: 0.9991
label: POSITIVE, score: 0.9999
[{"label": "POSITIVE", "score": 0.9998743534088135}, {"label": "NEGATIVE", "score": 0.9996669292449951}]
["audio-classification", "automatic-speech-recognition", "conversational", "depth-estimation", "document-question-answering", "feature-extraction", "fill-mask", "image-classification", "image-feature-extraction", "image-segmentation", "image-to-image", "image-to-text", "mask-generation", "ner", "object-detection", "question-answering", "sentiment-analysis", "summarization", "table-question-answering", "text-classification", "text-generation", "text-to-audio", "text-to-speech", "text2text-generation", "token-classification", "translation", "video-classification", "visual-question-answering", "vqa", "zero-shot-audio-classification", "zero-shot-classification", "zero-shot-image-classification", "zero-shot-object-detection"]
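The score values in the log are probabilities: for this two-label SST-2 model, the text-classification pipeline applies a softmax to the model's logits and reports the label with the highest probability, mapped to a name through the model config's id2label. A standalone sketch of that post-processing step; `classify_logits` is a hypothetical helper and the logits are made-up numbers, not the model's real output.

```python
import math
from typing import Dict, List, Union

def classify_logits(logits: List[float], id2label: Dict[int, str]) -> Dict[str, Union[str, float]]:
    """Softmax over the logits, then pick the highest-probability label."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=lambda i: probs[i])
    return {"label": id2label[best], "score": probs[best]}

# Hypothetical logits; id2label mirrors the NEGATIVE/POSITIVE labels seen in the log
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
result = classify_logits([-4.0, 4.0], id2label)
print(result["label"], round(result["score"], 4))  # POSITIVE 0.9997
```

Because the softmax probabilities sum to 1, a score near 1 for one label implies a score near 0 for the other.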
Pipeline batching
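from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
import datasets

dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")
pipe = pipeline(task="sentiment-analysis")
# KeyDataset feeds only the "text" field of each example to the pipeline;
# batch_size=8 groups inputs into batches of 8, truncation cuts texts longer than the model's limit
for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):
    print(out)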
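KeyDataset is needed because the pipeline expects each item to be an input text, while a datasets example is a dict with several fields. Conceptually it is a thin view that returns a single column. The class below is a hypothetical stand-in written over a plain list of dicts so that it runs without datasets or torch, and `batches` only sketches the grouping that batch_size implies; the real pipeline batches internally with a PyTorch DataLoader.

```python
from typing import Any, Dict, Iterator, List, Sequence

class KeyDatasetSketch:
    """Hypothetical stand-in for KeyDataset: exposes a single field of each example."""

    def __init__(self, dataset: Sequence[Dict[str, Any]], key: str):
        self.dataset = dataset
        self.key = key

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, i: int) -> Any:
        return self.dataset[i][self.key]

def batches(items: KeyDatasetSketch, batch_size: int) -> Iterator[List[Any]]:
    """Group consecutive items into lists of at most batch_size (the last one may be shorter)."""
    for start in range(0, len(items), batch_size):
        yield [items[i] for i in range(start, min(start + batch_size, len(items)))]

rows = [{"text": f"review {i}", "label": 0} for i in range(10)]
for batch in batches(KeyDatasetSketch(rows, "text"), batch_size=4):
    print(batch)
```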
Custom dataset
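from transformers import pipeline
from torch.utils.data import Dataset
from tqdm.auto import tqdm

# device=0 runs the model on the first GPU; use device=-1 (or omit it) to run on CPU
pipe = pipeline("text-classification", device=0)

class MyDataset(Dataset):
    # A dummy dataset: 5000 copies of the same sentence
    def __len__(self):
        return 5000

    def __getitem__(self, i):
        return "This is a test"

dataset = MyDataset()

# Compare batch sizes; results are streamed, so the inner loop just consumes them
for batch_size in [1, 8, 64, 256]:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    for out in tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
        pass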
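One way to reason about this benchmark: with batch_size=b, the model is called ceil(n / b) times, so larger batches mean fewer but bigger forward passes. Whether that is actually faster depends on the hardware and on how much padding each batch needs; the transformers documentation cautions that batching is not automatically faster and should be measured. The arithmetic for the 5000 items above, with `forward_passes` as a hypothetical helper:

```python
def forward_passes(n_items: int, batch_size: int) -> int:
    """Number of model calls needed to process n_items in batches of batch_size (ceiling division)."""
    return (n_items + batch_size - 1) // batch_size

for bs in [1, 8, 64, 256]:
    print(f"batch_size={bs}: {forward_passes(5000, bs)} forward passes")
# batch_size=1: 5000, batch_size=8: 625, batch_size=64: 79, batch_size=256: 20
```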
Text summarization
from transformers import pipeline

# Use the default model (DistilBART, sshleifer/distilbart-cnn-12-6) with PyTorch
summarizer = pipeline("summarization")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)

# Use T5 with TensorFlow (requires TensorFlow to be installed)
summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
References
- https://huggingface.co/docs/transformers/main_classes/pipelines
- https://huggingface.co/docs/transformers/task_summary#text-classification