当前位置: 首页 > news >正文

huggingface笔记:文本生成Text generation

1 加载LLM模型

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import osmodel = AutoModelForCausalLM.from_pretrained("gpt2",device_map="auto",  # 自动分配到所有可用设备(优先 GPU)torch_dtype=torch.bfloat16
)

2 编码输入并生成文本

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
model_inputs = tokenizer(["A list of colors: red, blue"], return_tensors="pt")
model_inputs
'''
{'input_ids': tensor([[  32, 1351,  286, 7577,   25, 2266,   11, 4171]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}
'''

2.1 调用 generate() 并使用 batch_decode() 还原文本:'

generated_ids = model.generate(**model_inputs)
generated_ids
'''
tensor([[  32, 1351,  286, 7577,   25, 2266,   11, 4171,   11, 4077,   11, 4171,11, 7872,   11, 7872,   11, 4077,   11, 4171,   11, 7872,   11, 4077,11, 4171,   11, 7872]])
'''
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
#A list of colors: red, blue, green, blue, yellow, yellow, green, blue, yellow, green, blue, yellow

3 常用参数

max_new_tokens最大生成 token 数
do_sample

是否使用采样策略(默认为False)

根据词表中每个 token 的概率随机抽取

num_beamsBeam search 会在每一步保留num_beams个候选序列(称为 beam),最终选择总体概率最高的那一条。
temperature

控制生成随机性(>0.8 适合创意任务,<0.4 更“严谨”)

需配合 do_sample=True

repetition_penalty>1 可减少重复内容
generated_ids = model.generate(**model_inputs,max_new_tokens=50,do_sample=True,temperature=0.9)
tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
'''
A list of colors: red, blue, white , yellow and purpleThere is a separate link in the sidebar of this page to see how this affects the colors of the text. Click the "Color Information" button. Then click "Next" to add this color information to your
'''

http://www.lryc.cn/news/582744.html

相关文章:

  • 【Node.js】文本与 pdf 的相互转换
  • 在 Linux(openEuler 24.03 LTS-SP1)上安装 Kubernetes + KubeSphere 的防火墙放行全攻略
  • 京东携手HarmonyOS SDK首发家电AR高精摆放功能
  • 代码详细注释:嵌入式Linux LCD汉字显示程序(基于font.h字库头文件)
  • 移动机器人的认知进化:Deepoc大模型重构寻迹本质
  • 数据库表设计:图片存储与自定义数据类型的实战指南
  • FlashAttention 深入浅出
  • C++STL详解(一):string类
  • Spring Boot:影响事务回滚的几种情况
  • Java List 使用详解:从入门到精通
  • 联通线路物理服务器选择的关键要点
  • 短视频矩阵系统的崛起:批量发布功能与多平台矩阵的未来
  • Redis基础学习(五大值数据类型的常用操作命令)
  • 中韩SD-WAN网络加速专线:提升国内与韩国公司网络性能的关键
  • ThreadPoolTaskExecutor 的使用案例
  • 东南亚主播解决方案|东南亚 TikTok 直播专线:纯净住宅 IP 、直播不卡顿
  • 分布式理论:CAP、Base理论
  • iOS打包流程
  • C++11 算法详解:std::copy_if 与 std::copy_n
  • 库制作与原理
  • Web前端开发: :where(伪类函数选择器)
  • Python之--列表
  • 实时音视频通过UDP打洞实现P2P优先通信
  • Python爬虫实战:研究python-nameparser库相关技术
  • nvm npm nrm 使用教程
  • Crazyflie支持MATLAB/Simulink控制 基于NOKOV度量动捕系统实现
  • 安装 asciidoctor-vscode 最新版
  • 【Python篇】PyCharm 安装与基础配置指南
  • Spring AI 基本组件详解 —— ChatClient、Prompt、Memory
  • Fiddler-关于抓取Android手机包,安装证书后页面加载失败,提示当前证书不可信存在安全风险的问题