当前位置: 首页 > news >正文

关于MediaEval数据集的Dataset构建(Text部分-使用PLM BERT)

import random
import numpy as np
import pandas as pd
import torch
from transformers import BertModel,BertTokenizer
from tqdm.auto import tqdm
from torch.utils.data import Dataset
import re
"""参考Game-On论文"""
"""util.py"""
def set_seed(seed_value=42):random.seed(seed_value)np.random.seed(seed_value)# 用于设置生成随机数的种子torch.manual_seed(seed_value)torch.cuda.manual_seed_all(seed_value)
"""util.py""""""文本预处理-textGraph.py"""
# 文本DataSet类def text_preprocessing(text):"""- Remove entity mentions (eg. '@united')- Correct errors (eg. '&amp;' to '&')@param    text (str): a string to be processed.@return   text (Str): the processed string."""# Remove '@name'text = re.sub(r'(@.*?)[\s]', ' ', text)# Replace '&amp;' with '&'text = re.sub(r'&amp;', '&', text)# Remove trailing whitespacetext = re.sub(r'\s+', ' ', text).strip()# removes linkstext = re.sub(r'(?P<url>https?://[^\s]+)', r'', text)# remove @usernamestext = re.sub(r"\@(\w+)", "", text)# remove # from #tagstext = text.replace('#', '')return textclass TextDataset(Dataset):def __init__(self,df,tokenizer):# 包含推文的主文件框架self.df = df.reset_index(drop=True)# 使用的分词器self.tokenizer = tokenizerdef __len__(self):return len(self.df)def __getitem__(self, idx):if torch.is_tensor(idx):idx = idx.tolist()# 帖子的文本内容text = self.df['tweetText'][idx]# 作为唯一标识符的id ‘tweetId'unique_id = self.df['tweetId'][idx]# 创建一个空的列表来存储输出结果input_ids = []attention_mask = []# 使用tokenizer分词器encoded_sent = self.tokenizer.encode_plus(text = text_preprocessing(text), # 这里使用的是预处理的句子,而不是直接对原句子使用tokenizeradd_special_tokens=True,        # 添加[CLS]以及[SEP]等特殊词元max_length=512,                 # 最大截断长度padding='max_length',            # padding的最大长度return_attention_mask=True,     # 返回attention_masktruncation=True                 #)# 获取编码效果input_ids = encoded_sent.get('input_ids')# 获取attention_mask结果attention_mask = encoded_sent.get('attention_mask')# 将列表转换成张量input_ids = torch.tensor(input_ids)attention_mask =torch.tensor(attention_mask)return {'input_ids':input_ids,'attention_mask':attention_mask,'unique_id':unique_id}def store_data(bert,device,df,dataset,store_dir):lengths = []bert.eval()for idx in tqdm(range(len(df))):sample = dataset.__getitem__(idx)print('原始sample[input_ids]和sample[attention_mask]的维度:',sample['input_ids'].shape,sample['attention_mask'].shape)# 升维input_ids,attention_mask = sample['input_ids'].unsqueeze(0),sample['attention_mask'].unsqueeze(0)input_ids = input_ids.to(device)attention_mask = attention_mask.to(device)# 得到唯一标识属性unique_id = sample['unique_id']# 计算token的个数num_tokens = attention_mask.sum().detach().cpu().item()"""不生成新的计算图,而是只做权重更新"""with torch.no_grad():out = bert(input_ids=input_ids,attention_mask=attention_mask)# last_hidden_state.shape是(batch_size,sequence_length,hidden_size)out_tokens = out.last_hidden_state[:,1:num_tokens,:].detach().cpu().squeeze(0).numpy() # token向量# 保存token级别表示filename = f'{emed_dir}{unique_id}.npy'try:np.save(filename, out_tokens)print(f"文件{filename}保存成功")except FileNotFoundError:# 文件不存在,创建新文件并保存np.save(filename, out_tokens)print(f"文件{filename}创建成功并保存成功")lengths.append(num_tokens)## Save semantic/ whole text representation# 保存语义  也就是整个文本的表示out_cls = out.last_hidden_state[:,0,:].unsqueeze(0).detach().cpu().squeeze(0).numpy() ## cls vectorfilename = f'{emed_dir}{unique_id}_full_text.npy'# 尝试保存.npy文件,如果文件不存在则自动创建try:np.save(filename, out_cls)print(f"文件{filename}保存成功")except FileNotFoundError:# 文件不存在,创建新文件并保存np.save(filename, out_cls)print(f"文件{filename}创建成功并保存成功")return lengthsif __name__=='__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 根目录root_dir = "./dataset/image-verification-corpus-master/image-verification-corpus-master/mediaeval2015/"emed_dir = './Embedding_File'# 文件路径train_csv_name = "tweetsTrain.csv"test_csv_name = "tweetsTest.csv"# 加载PLM和分词器tokenizer = BertTokenizer.from_pretrained('./bert/')bert = BertModel.from_pretrained('./bert/', return_dict=True)bert = bert.to(device)# 用于存储每个推文的Embeddingstore_dir ="Embed_Post/"# 创建训练数据集的Embedding表示df_train = pd.read_csv(f'{root_dir}{train_csv_name}')df_train = df_train.dropna().reset_index(drop=True)# 训练数据集的编码结果train_dataset = TextDataset(df_train,tokenizer)lengths = store_data(bert, device, df_train, train_dataset, store_dir)## Create graph data for testing set# 为测试集创建Embedding表示df_test = pd.read_csv(f'{root_dir}{test_csv_name}')df_test = df_test.dropna().reset_index(drop=True)test_dataset = TextDataset(df_test, tokenizer)lengths = store_data(bert, device, df_test, test_dataset, store_dir)"""文本预处理-textGraph.py"""
http://www.lryc.cn/news/310843.html

相关文章:

  • QML学习之Text
  • 轮转数组(元素位置对调、数据的左旋、右旋)
  • 喜迎乔迁,开启新章 ▏易我科技新办公区乔迁庆典隆重举行
  • 多个地区地图可视化
  • 学习使用paddle来构造hrnet网络模型
  • Redis 多线程操作同一个Key如何保证一致性?
  • 单链表合并
  • 【如何像网吧一样弄个游戏菜单在家里】
  • CSS~~
  • Docker技术概论(1):Docker与虚拟化技术比较
  • alibabacloud学习笔记07(小滴课堂)
  • Ansible-Playbook
  • UE5常见问题处理笔记
  • docker中hyperf项目配置虚拟域名
  • PID闭环控制算法的学习与简单使用
  • 【无刷电机学习】光耦HCNR200基本原理及应用(资料摘抄)
  • 【LeetCode】1768_交替合并字符串_C
  • C#解析JSON
  • AI图像识别算法助力安全生产*提升风险监测效率---豌豆云
  • CSS技巧:实现两个div在同一行显示的方法
  • 【Unity】Node.js安装与配置环境
  • Vue3:使用 Composition API 不需要 Pinia
  • ExoPlayer 播放视频黑屏的解决方法
  • C语言初阶—数组
  • 飞桨(PaddlePaddle)数据预处理教程
  • MYSQL C++链接接口编程
  • 并发编程中常见的设计模式,c++多线程如何设计
  • 解决android studio build Output中文乱码
  • [云原生] K8s之pod进阶
  • [Unity3d] 网络开发基础【个人复习笔记/有不足之处欢迎斧正/侵删】