当前位置: 首页 > news >正文

用 TripletLoss 优化 BERT ranking

下面是用 TripletLoss 优化 BERT ranking 的 demo:


import torch
from sklearn.metrics.pairwise import pairwise_distances
from torch.utils.data import DataLoader, Dataset
from transformers import BertConfig, BertModel, BertTokenizer
class TripletRankingDataset(Dataset):
    """Pre-tokenized (query, positive document, negative document) triplets.

    All texts are tokenized once in the constructor, so ``__getitem__`` only
    indexes into the stored tensors.

    Args:
        queries: Query strings.
        positive_docs: For each query, a document that should rank close to it.
        negative_docs: For each query, a document that should rank far from it.
        tokenizer: Callable tokenizer (for example a Hugging Face
            ``BertTokenizer``). It is called with a list of strings plus
            ``padding``, ``truncation``, ``max_length`` and ``return_tensors``
            and must return a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors, one row per input string.
        max_length: Length passed to the tokenizer for padding/truncation.

    Raises:
        ValueError: If the inputs do not contain a single complete triplet.
    """

    def __init__(self, queries, positive_docs, negative_docs, tokenizer, max_length):
        # zip() keeps the three columns aligned; as in the original loop, any
        # surplus items of a longer input are ignored.
        triplets = list(zip(queries, positive_docs, negative_docs))
        if not triplets:
            raise ValueError('TripletRankingDataset needs at least one triplet')
        query_texts, positive_texts, negative_texts = map(list, zip(*triplets))

        self.input_ids_q, self.attention_masks_q = self._encode(
            tokenizer, query_texts, max_length)
        self.input_ids_p, self.attention_masks_p = self._encode(
            tokenizer, positive_texts, max_length)
        self.input_ids_n, self.attention_masks_n = self._encode(
            tokenizer, negative_texts, max_length)

    @staticmethod
    def _encode(tokenizer, texts, max_length):
        """Tokenize ``texts`` in one batch; return ``(input_ids, attention_mask)``."""
        # Calling the tokenizer on the whole list replaces the per-item
        # encode_plus() calls followed by torch.cat().
        encoded = tokenizer(
            texts,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        return encoded['input_ids'], encoded['attention_mask']

    def __len__(self):
        """Return the number of stored triplets."""
        return len(self.input_ids_q)

    def __getitem__(self, idx):
        """Return the six tensors of triplet ``idx``.

        Order: query ids, query mask, positive ids, positive mask,
        negative ids, negative mask.
        """
        return (
            self.input_ids_q[idx],
            self.attention_masks_q[idx],
            self.input_ids_p[idx],
            self.attention_masks_p[idx],
            self.input_ids_n[idx],
            self.attention_masks_n[idx],
        )


class BERTTripletRankingModel(torch.nn.Module):
    """BERT encoder followed by dropout and a linear head with one output.

    Args:
        bert_model_name: Name or path handed to ``BertModel.from_pretrained``.
        hidden_size: Input width of the linear head; it has to match the width
            of the encoder output that is fed into the head.
    """

    def __init__(self, bert_model_name, hidden_size):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one scalar per input sequence, shaped ``(batch,)``.

        Reshape the result to ``(batch, 1)`` before handing it to a
        distance-based loss that expects ``(batch, dim)`` inputs.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is used as the sequence summary
        # (the pooled output for a Hugging Face BertModel).
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        # Drop only the trailing size-1 dimension. A bare squeeze() would also
        # remove the batch dimension whenever a batch holds a single sample.
        return logits.squeeze(-1)


def triplet_loss(anchor, positive, negative, margin):
    """Return the mean hinge triplet loss ``relu(d(a, p) - d(a, n) + margin)``.

    Distances come from ``torch.nn.functional.pairwise_distance`` with its
    default settings. Pass ``(batch, dim)`` tensors to get one distance per
    triplet; that function interprets 1-D inputs as a single unbatched
    vector, which yields one distance for the whole tensor.

    Args:
        anchor: Anchor representations.
        positive: Representations that should end up close to the anchors.
        negative: Representations that should end up far from the anchors.
        margin: Minimum required gap between the two distances.

    Returns:
        A scalar tensor holding the mean loss.
    """
    distance_positive = torch.nn.functional.pairwise_distance(anchor, positive)
    distance_negative = torch.nn.functional.pairwise_distance(anchor, negative)
    losses = torch.relu(distance_positive - distance_negative + margin)
    return losses.mean()


# Initialize the BERT model and tokenizer
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Example input data: one positive and one negative document per query
queries = ['I like cats', 'The sun is shining']
positive_docs = ['I like dogs', 'The weather is beautiful']
negative_docs = ['Snakes are dangerous', 'It is raining']

# Hyperparameters
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5
margin = 1.0

# Build the dataset and the data loader
dataset = TripletRankingDataset(queries, positive_docs, negative_docs, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model with the pretrained weights. The hidden size is read
# from the checkpoint's config because `model` does not exist yet at this
# point and therefore cannot be asked for it.
hidden_size = BertConfig.from_pretrained(bert_model_name).hidden_size
model = BERTTripletRankingModel(bert_model_name, hidden_size=hidden_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Train the model
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        (input_ids_q, attention_masks_q,
         input_ids_p, attention_masks_p,
         input_ids_n, attention_masks_n) = batch

        optimizer.zero_grad()
        # The model yields one scalar per sequence. reshape(-1, 1) turns that
        # into (batch, 1) so that triplet_loss measures one distance per
        # triplet instead of a single distance across the whole batch.
        embeddings_q = model(input_ids_q, attention_masks_q).reshape(-1, 1)
        embeddings_p = model(input_ids_p, attention_masks_p).reshape(-1, 1)
        embeddings_n = model(input_ids_n, attention_masks_n).reshape(-1, 1)

        loss = triplet_loss(embeddings_q, embeddings_p, embeddings_n, margin)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

# Inference: score every query and every positive document with the trained
# model, then measure all query/document distances.
model.eval()
with torch.no_grad():
    query_repr = model(dataset.input_ids_q, dataset.attention_masks_q).reshape(-1, 1)
    doc_repr = model(dataset.input_ids_p, dataset.attention_masks_p).reshape(-1, 1)
# distance_matrix[i][j] is the distance between query i and document j.
# A separate name is used so the imported pairwise_distances stays callable.
distance_matrix = pairwise_distances(query_repr.numpy(), doc_repr.numpy())

# Print the results
for i, query in enumerate(queries):
    print(f"Query: {query}")
    print("Documents:")
    for j, doc in enumerate(positive_docs):
        doc_dist = distance_matrix[i][j]
        print(f"Document index: {j}, Distance: {doc_dist:.4f}")
        print(f"Document: {doc}")
        print("")
    print("---------")

http://www.lryc.cn/news/160937.html

相关文章:

  • Tomcat安装及使用
  • 法国新法案强迫 Firefox 等浏览器审查网站
  • 开源电商项目 Mall:构建高效电商系统的终极选择
  • QT(9.1)对话框与事件处理
  • C++项目实战——基于多设计模式下的同步异步日志系统-③-前置知识补充-设计模式
  • C++ 新旧版本两种读写锁
  • ES6 字符串的repeat()方法
  • 【车载以太网测试从入门到精通】系列文章目录汇总
  • LLM推理优化技术综述:KVCache、PageAttention、FlashAttention、MQA、GQA
  • go开发之个微机器人的二次开发
  • 2023国赛数学建模B题思路代码 - 多波束测线问题
  • SpringAOP面向切面编程
  • A Guide to Java HashMap
  • LeetCode 449. Serialize and Deserialize BST【树,BFS,DFS,栈】困难
  • 嵌入式IDE(1):IAR中ICF链接文件详解和实例分析
  • 分布式版本控制工具——git
  • C基础-数组
  • springboot项目配置flyway菜鸟级别教程
  • 成都精灵云初试
  • css relative 和absolute布局
  • 更健康舒适更科技的照明体验!书客SKY护眼台灯SUKER L1上手体验
  • 经管博士科研基础【19】齐次线性方程组
  • django报错解决 Forbidden (403) CSRF verification failed. Request aborted.
  • k8s-实战——yapi平台部署
  • Excel VSTO开发5 -Excel对象结构
  • Javafx集成sqlite数据库
  • react-native实现 TextInput 键盘显示搜索按钮并触发回调
  • 人大金仓分析型数据库备份和恢复(五)
  • lenovo联想笔记本ThinkPad P16V Gen 1(21FC,21FD)原装出厂Win11系统
  • Django实现音乐网站 ⒃