当前位置: 首页 > news >正文

【论文解读】GPT Understands, Too

一.论文

1.1 P-tuning

区别于之前的工作,这篇工作认为promote可以在句子中的任意位置起到作用,可以将它们插入上下文或目标中

上图中,左图是不使用任何操作,右图是选择在居首和目标前插入promote的embedding,插入promote的过程可以表示为

其中x代表一系列离散的输入令牌,y代表目标(可以理解为希望模型想要给你的回答),e()表示对应的embedding,其实就是将其参数化映射成为伪tokens,即

通过最小化这些参数

1.2 promote生成

嵌入的promote实际上可以理解为不一定离散不相互关联的,而实际上的promote其实应该是高度离散的且具有关联性的,因此作者选择使用双向长短期记忆网络(LSTM),激活函数和MLP来建模这种关系

在推理中,我们只需要输出嵌入h,并且可以丢弃LSTM头

二.代码

本质上是使用一个PromptEncoder来生成伪的embedding添加到原先的embedding中

2.1 训练

训练过程只更新promote_encoder中的参数

 2.1.1 PromptEncoder

在PTuneForLAMA中实例化了PromptEncoder

 PromptEncoder本质上是一个(嵌入 + LSTM + MLP)

import torch
import torch.nn as nnclass PromptEncoder(torch.nn.Module):def __init__(self, template, hidden_size, tokenizer, device, args):super().__init__()self.device = deviceself.spell_length = sum(template)self.hidden_size = hidden_sizeself.tokenizer = tokenizerself.args = args# ent embeddingself.cloze_length = templateself.cloze_mask = [[1] * self.cloze_length[0]  # first cloze+ [1] * self.cloze_length[1]  # second cloze+ [1] * self.cloze_length[2]  # third cloze]self.cloze_mask = torch.LongTensor(self.cloze_mask).bool().to(self.device)self.seq_indices = torch.LongTensor(list(range(len(self.cloze_mask[0])))).to(self.device)# embeddingself.embedding = torch.nn.Embedding(len(self.cloze_mask[0]), self.hidden_size).to(self.device)# LSTMself.lstm_head = torch.nn.LSTM(input_size=self.hidden_size,hidden_size=self.hidden_size // 2,num_layers=2,dropout=self.args.lstm_dropout,bidirectional=True,batch_first=True)self.mlp_head = nn.Sequential(nn.Linear(self.hidden_size, self.hidden_size),nn.ReLU(),nn.Linear(self.hidden_size, self.hidden_size))print("init prompt encoder...")def forward(self):input_embeds = self.embedding(self.seq_indices).unsqueeze(0)output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]).squeeze()return output_embeds

2.1.2 调用

在PTuneForLAMA的forward函数中调用了embed_input来实现

http://www.lryc.cn/news/233409.html

相关文章:

  • 组合式API_生命周期
  • WPF如何实现应用程序托盘
  • ERROR: column “xxxx.id“ must appear in the GROUP BY
  • 【C++ 学习 ㊲】- 五种特殊类的设计
  • 探索arkui(2)--- 布局(列表)--- 2(支持分组/实现响应滚动位置)
  • systemverilog:interface中端口方向理解
  • 【GUI】-- 08 JButton、JRadioButton、JCheckBox
  • 【postgresql】CentOS7 安装Pgweb
  • 基于python和定向爬虫的商品比价系统
  • 使用GPT-4训练数据微调GPT-3.5 RAG管道
  • 二十三种设计模式全面解析-深入解析模板方法模式的奇妙世界
  • 【Spring】加载properties文件
  • react中间件的理解
  • React函数组件状态Hook—useState《进阶-对象数组》
  • linux 网络 cat /proc/net/dev 查看测试网络丢包情况
  • 记录配置VS,使用opencv与Eigen
  • uart控制led与beep
  • Linux修改root密码
  • C/C++模板类模板与函数模板区别,以及用法详解
  • van-dialog弹窗异步关闭-校验表单
  • Dynamic Wallpaper 16.7中文版
  • ​如何使用ArcGIS Pro制作渐变河流效果
  • 《网络协议》06. HTTP 补充 · HTTPS · SSL/TLS
  • Python winreg将cmd/PowerShell(管理员)添加到右键菜单
  • redis运维(九)字符串(二)字符串过期时间
  • 【C++】多线程的学习笔记(3)——白话文版(bushi
  • kotlin--3.集合操作
  • 自动驾驶-BEV感知综述
  • 面试题-3
  • C++ Core Guidelines 中文版 GSL