当前位置: 首页 > news >正文

使用Bert预训练模型处理序列推荐任务

最近的工作有涉及该任务,整理一下思路以及代码细节。

流程

总体来说思路就是首先用预训练的bert模型,在训练集的序列上进行CLS任务。对序列内容(这里默认是token id的sequence)以0.3左右的概率进行随机mask,然后将相应sequence的attention mask(原来决定padding index)和label(也就是mask的ground truth)输入到bert model里面。

当然其中vocab.txt并不存在的token是需要add进去的,具体方法不再详述,网上例子很多,注意word embedding也需要初始化就行。

模型定义:
self.model = AutoModelForMaskedLM.from_pretrained('./bert')
模型的输入:
result = self.bert_model(tail_mask, attention_mask, labels)
得到模型训练的结果之后,要做一个选择:

(1)transformer的bert model可以输出要预测时间步的hidden state,可以选择取出对应的hidden state,其中需要在数据处理的时候记录下每个sequence的tail position,也就是要预测位置的idx。另外我认为既然要进行序列推荐,那么最后一个tail position的token表征一定是最重要的,所以需要对tail position的idx专门给个写死的mask,效果会好一些。然后与sequence中item的全集进行相似度的计算,再去算交叉熵loss。

bert_hidden = result.hidden_states[-1]
bert_seq_hidden = torch.zeros((self.args.batch_size, 312)).to(self.device)
for i in range(self.args.batch_size):bert_seq_hidden[i,:] = bert_hidden[i, tail_pos[i], :]
logits = torch.matmul(bert_seq_hidden, test_item_emb.transpose(0, 1))
main_loss = self.criterion(logits, targets)

(2)同时也可以result.loss直接数据mask prediction的loss,我理解这个loss面对的任务是我要求sequence中的各个token表征都要尽可能准确,都要考虑,(1)可能更加注重最后一个位置的标准的准确性。

然后在evaluate阶段,需要注意输入到模型的不再是tail_mask,而是仅仅mask掉tail token id的sequence,因为我们需要尽可能准确的序列信息,只需要保证要预测的存在mask就够了。

由于是推荐任务,而且bert得到的hidden state表征过于隐式,所以需要一定的个性化引导它进行训练。经过个人的实验也确实如此,而且结果相差很多。

以上就是我个人的总结经验,欢迎大家指点。

http://www.lryc.cn/news/110970.html

相关文章:

  • 将word每页页眉单独设置
  • rust怎么生成随机数?
  • python-Excel数据模型文档转为MySQL数据库建表语句(需要连接数据库)-工作小记
  • 406 · 和大于S的最小子数组
  • xray的 webhook如何把它Hook住?^(* ̄(oo) ̄)^
  • 浅析RabbitMQ死信队列
  • ELK 企业级日志分析系统(ElasticSearch、Logstash 和 Kiabana 详解)
  • 数学建模—多元线性回归分析
  • win10 64位 vs2017 qt5.12.6 pcl1.9.1 vtk8.1.1配置安装步骤
  • 【项目 计网1】4.1 网络结构模式 4.2MAC地址、IP地址、端口
  • uni-app:分页实现多选功能
  • 问道管理:沪指窄幅震荡跌0.18%,有色、汽车等板块走低
  • Kotlin 协程与 Flow
  • 设备管理系统与物联网的融合:实现智能化设备监控和维护
  • 三、从官方源码精简出第1个FreeRTOS
  • __call__函数的用法
  • golang定时任务库cron实践
  • Julia 流程控制
  • 问题解决方案
  • kubernetes基于helm部署gitlab-operator
  • ChatGPT在在线客服和呼叫中心中的应用如何?
  • C++多线程环境下的单例类对象创建
  • “深入解析JVM内部机制:从字节码到垃圾回收“
  • 音频系统项目与音频算法研究方向分类
  • 单例模式和工厂模式
  • 两个镜头、视野、分辨率不同的相机(rgb、红外)的视野校正
  • kettle 连接jdbc
  • PyTorch中加载模型权重 A匹配B|A不匹配B
  • @FeignClient指定多个url实现负载均衡
  • vue diff 双端比较算法