当前位置: 首页 > news >正文

bert ranking pairwise demo

下面是用 BERT 训练 pairwise ranking(成对排序)模型的 demo:

import torch
from sklearn.metrics import pairwise_distances_argmin_min
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
class PairwiseRankingDataset(Dataset):
    """Tokenized sentences for pairwise ranking, stored one row per sentence.

    Pair ``k`` of ``sentence_pairs`` occupies row ``2k`` (first sentence) and
    row ``2k + 1`` (second sentence), so the dataset length is twice the
    number of pairs. Consumers that split scores with ``[::2]`` / ``[1::2]``
    depend on this interleaved layout.

    Args:
        sentence_pairs: Iterable of 2-element sequences of strings.
        tokenizer: Callable accepting a list of strings plus the keyword
            arguments ``padding``, ``truncation``, ``max_length`` and
            ``return_tensors``, and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Length every sentence is padded / truncated to.

    Raises:
        ValueError: If an element of ``sentence_pairs`` is not a sequence of
            exactly two sentences.
    """

    def __init__(self, sentence_pairs, tokenizer, max_length):
        sentences = []
        for pair in sentence_pairs:
            # A bare string or a wrong-sized tuple would silently break the
            # even/odd row layout, so reject it up front.
            if isinstance(pair, str) or len(pair) != 2:
                raise ValueError(
                    f"Each item must contain exactly two sentences, got: {pair!r}"
                )
            sentences.extend(pair)

        if not sentences:
            # Nothing to tokenize; keep well-formed empty tensors so that
            # len() and indexing behave consistently.
            self.input_ids = torch.empty((0, max_length), dtype=torch.long)
            self.attention_masks = torch.empty((0, max_length), dtype=torch.long)
            return

        # Each sentence is encoded independently (as a batch of single
        # sentences), one tokenizer call for the whole dataset.
        encoded = tokenizer(
            sentences,
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        self.input_ids = encoded['input_ids']
        self.attention_masks = encoded['attention_mask']

    def __len__(self):
        """Return the number of stored sentences (two per input pair)."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` for the sentence at ``idx``."""
        return self.input_ids[idx], self.attention_masks[idx]


class BERTPairwiseRankingModel(torch.nn.Module):
    """BERT encoder followed by dropout and a linear layer producing one score.

    Args:
        bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, bert_model_name):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Score each input row.

        Args:
            input_ids: Token id tensor, one row per sentence.
            attention_mask: Mask tensor matching ``input_ids``.

        Returns:
            A 1-D tensor with one score per input row.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is used as the sentence
        # representation (the pooled output for BertModel).
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        # Drop only the trailing size-1 feature axis; a bare squeeze() would
        # also collapse a batch of one into a 0-d tensor.
        return logits.squeeze(-1)


# Initialize the BERT model and tokenizer
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Example input data: each tuple is (preferred sentence, less-preferred sentence).
sentence_pairs = [
    ('I like cats', 'I like dogs'),
    ('The sun is shining', 'It is raining'),
    ('Apple is a fruit', 'Car is a vehicle'),
]

# Hyperparameters
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5


class _PairBatchSampler:
    """Yield index batches that never separate the two rows of a pair.

    The dataset stores pair ``k`` at rows ``2k`` and ``2k + 1``. Shuffling
    individual rows would break the even/odd slicing used by the loss, so
    this sampler shuffles whole pairs and emits both rows back to back.

    Args:
        num_rows: Total number of rows in the dataset (must be even).
        batch_size: Maximum number of rows per batch; rounded down to an
            even number, with a minimum of one pair per batch.

    Raises:
        ValueError: If ``num_rows`` is odd.
    """

    def __init__(self, num_rows, batch_size):
        if num_rows % 2:
            raise ValueError(f"Expected an even number of rows, got {num_rows}")
        self.num_pairs = num_rows // 2
        self.pairs_per_batch = max(1, batch_size // 2)

    def __iter__(self):
        # A fresh permutation on every pass gives a new order each epoch.
        order = torch.randperm(self.num_pairs).tolist()
        for start in range(0, self.num_pairs, self.pairs_per_batch):
            batch = []
            for pair_idx in order[start:start + self.pairs_per_batch]:
                batch.extend((2 * pair_idx, 2 * pair_idx + 1))
            yield batch

    def __len__(self):
        # Ceiling division: the last batch may hold fewer pairs.
        return -(-self.num_pairs // self.pairs_per_batch)


# Create the dataset and the data loader
dataset = PairwiseRankingDataset(sentence_pairs, tokenizer, max_length)
dataloader = DataLoader(
    dataset,
    batch_sampler=_PairBatchSampler(len(dataset), batch_size),
)

# Initialize the model (loads the pretrained weights)
model = BERTPairwiseRankingModel(bert_model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Train the model
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for input_ids, attention_masks in dataloader:
        optimizer.zero_grad()
        logits = model(input_ids, attention_masks)

        # Margin ranking (hinge) loss: the first sentence of each pair should
        # outscore the second by at least 1. Rows alternate first/second
        # because the sampler keeps pairs adjacent.
        pos_scores = logits[::2]   # scores of the first sentences
        neg_scores = logits[1::2]  # scores of the second sentences
        loss = torch.relu(1 - pos_scores + neg_scores).mean()

        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

# Inference: score every sentence with the trained ranking head, in dataset
# order so that rows 2k / 2k + 1 still map back to pair k.
model.eval()
eval_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
with torch.no_grad():
    # reshape(-1) keeps every chunk 1-D even if the final batch has one row.
    score_chunks = [
        model(input_ids, attention_masks).reshape(-1)
        for input_ids, attention_masks in eval_loader
    ]
scores = torch.cat(score_chunks) if score_chunks else torch.empty(0)

# Print the results
for i, pair in enumerate(sentence_pairs):
    pos_score = scores[2 * i].item()
    neg_score = scores[2 * i + 1].item()
    print(f"Pair: {pair}")
    print(f"Positive example score: {pos_score:.4f}")
    print(f"Negative example score: {neg_score:.4f}")
    print()

http://www.lryc.cn/news/160269.html

相关文章:

  • GPT引领前沿与应用突破之GPT4科研实践技术与AI绘图
  • SpringBoot整合Swagger3
  • detectron2 install path
  • 如何将DHTMLX Suite集成到Scheduler Lightbox中?让项目管理更可控!
  • 什么是JVM常用调优策略?分别有哪些?
  • 《向量数据库指南》——向量数据库Milvus Cloud 2.3的可运维性:从理论到实践
  • select多选回显问题 (取巧~)
  • 光伏并网双向计量表ADL400
  • 十三、MySQL(DQL)语句执行顺序
  • 【高德地图】根据经纬度多边形的绘制(可绘制区域以及任意图形)
  • C++ std::pair and std::list \ std::array
  • C++的类型转换
  • 【Selenium2+python】自动化unittest生成测试报告
  • 【APISIX】W10安装APISIX
  • [Linux]动静态库
  • 2023高教社杯数学建模国赛C题思路解析+代码+论文
  • macos13 arm芯片(m2) 搭建hbase docker容器 并用flink通过自定义richSinkFunction写入数据到hbase
  • FLV封装格式
  • [NLP]LLM---FineTune自己的Llama2模型
  • git在linux情况下设置git 命令高亮
  • C++ 表驱动方法代替if-else
  • 2023国赛数学建模E题思路分析 - 黄河水沙监测数据分析
  • cadence后仿真/寄生参数提取/解决pin口提取不全的问题
  • Vue中实现3D得球自动旋转
  • 使用wkhtmltoimage实现生成长图分享
  • 新风机未来什么样?
  • python的几种数据类型的花样玩法(一)
  • python回调函数之获取jenkins构建结果
  • Docker底层实现
  • PY32F003F18之RS485通讯