
Training BART for Text Summarization with PaddleNLP

You can also swap in a Pegasus or T5 model; the process is essentially the same.
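For reference, here is a minimal sketch of swapping in T5 instead of BART. The class names exist in paddlenlp.transformers, but the pretrained identifier 't5-base' is an assumption; check the PaddleNLP model zoo for the names your version actually ships (the same applies to Pegasus).

# A hedged sketch: the same pipeline with T5 instead of BART.
# 't5-base' is an assumed identifier; verify it against the PaddleNLP model zoo.
from paddlenlp.transformers import T5ForConditionalGeneration, T5Tokenizer

model = T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
# Everything downstream (convert_example, DataCollatorForSeq2Seq, generate) stays unchanged.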

!pip install paddlenlp==2.3.2 
# !pip install rouge==1.0.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
# !pip install regex
import os
import json
import argparse
import random
import time
import distutils.util
from pprint import pprint
from functools import partial
from tqdm import tqdm
import numpy as np
import math
from datasets import load_dataset
import contextlib
from rouge import Rouge
from visualdl import LogWriter
import paddle
import paddle.nn as nn
from paddle.io import BatchSampler, DistributedBatchSampler, DataLoader
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.utils.log import logger
from paddlenlp.metrics import BLEU
from paddlenlp.data import DataCollatorForSeq2Seq
# Maximum length of the source text
max_source_length = 128
# Maximum length of the summary
max_target_length = 64
# Minimum length of the summary
min_target_length = 16
# Beam size for beam search decoding
num_beams = 4
def infer(text, model, tokenizer):
    tokenized = tokenizer(text,
                          truncation=True,
                          max_length=max_source_length,
                          return_tensors='pd')
    preds, _ = model.generate(input_ids=tokenized['input_ids'],
                              max_length=max_target_length,
                              min_length=min_target_length,
                              decode_strategy='beam_search',
                              num_beams=num_beams)
    return tokenizer.decode(preds[0],
                            skip_special_tokens=True,
                            clean_up_tokenization_spaces=False)
model = BartForConditionalGeneration.from_pretrained('bart-large')
tokenizer = BartTokenizer.from_pretrained('bart-large')
text = ' 根据实名举报者小农的讲述,她的表妹小华于2019年9月被诊断为重度抑郁症,多年治疗未见好转。直至今年1月17日,因无法承受心理压力,小华在家自杀身亡。家属在整理遗物时,发现她的日记、聊天记录等多次提到,自己遭高中教师唐某某的性侵,还遭到唐某某的言语威胁。'
infer(text, model, tokenizer)

1. Fine-tuning on your own data

train_dataset = load_dataset("json", data_files='./train2.json', split="train")
dev_dataset = load_dataset("json", data_files='./valid2.json', split="train")

The data format looks like this:

{"content":"霸饱巴辨倍步败巴傲澳颁秉步去伴。边斑八奥把。","title":"澳班奔爆财悲暴包臣扮。"}
{"content":"霸傲饱辨倍巴傲救颁秉步败班摆半瓣拔池畅版变。","title":"澳班傲半瓣厦摆并拔编导伴查。"}
{"content":"便傲必彼被杯奥并拔池碑版变。","title":"板傲必彼被杯碧奈暴伴查闭搬傲并拔编。"}
def convert_example(example, text_column, summary_column, tokenizer,
                    max_source_length, max_target_length):
    """Build the model inputs."""
    inputs = example[text_column]
    targets = example[summary_column]
    # Tokenize
    model_inputs = tokenizer(inputs,
                             max_length=max_source_length,
                             padding=False,
                             truncation=True,
                             return_attention_mask=True)
    labels = tokenizer(targets,
                       max_length=max_target_length,
                       padding=False,
                       truncation=True)
    # Use the target ids as labels; DataCollatorForSeq2Seq shifts them later
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
# The original fields are removed after conversion
remove_columns = ['content', 'title']
# Maximum length of the source text
max_source_length = 128
# Maximum length of the summary
max_target_length = 64
# Build the conversion function
trans_func = partial(convert_example,
                     text_column='content',
                     summary_column='title',
                     tokenizer=tokenizer,
                     max_source_length=max_source_length,
                     max_target_length=max_target_length)
# Convert train_dataset and dev_dataset separately
train_dataset = train_dataset.map(trans_func,
                                  batched=True,
                                  load_from_cache_file=True,
                                  remove_columns=remove_columns)
dev_dataset = dev_dataset.map(trans_func,
                              batched=True,
                              load_from_cache_file=True,
                              remove_columns=remove_columns)
# Print the first 3 samples of the dev set
for idx, example in enumerate(dev_dataset):
    if idx < 3:
        print(example)
# Batch assembly & padding
batchify_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
# Distributed batch sampler, used for multi-GPU distributed training
train_batch_sampler = DistributedBatchSampler(train_dataset, batch_size=12, shuffle=True)
# Build the training DataLoader
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_sampler=train_batch_sampler,
                               num_workers=0,
                               collate_fn=batchify_fn,
                               return_list=True)
dev_batch_sampler = BatchSampler(dev_dataset, batch_size=48, shuffle=False)
# Build the validation DataLoader
dev_data_loader = DataLoader(dataset=dev_dataset,
                             batch_sampler=dev_batch_sampler,
                             num_workers=0,
                             collate_fn=batchify_fn,
                             return_list=True)
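As a sanity check it can help to pull one collated batch and look at its keys and shapes. The exact key set (for example whether decoder_input_ids is present) depends on the collator and model version, so treat this as a rough sketch.

# Hedged sanity check: inspect one collated batch from the training DataLoader.
for batch in train_data_loader:
    for key, value in batch.items():
        # Expect padded tensors such as input_ids, attention_mask and labels
        # (and possibly decoder_input_ids, depending on DataCollatorForSeq2Seq).
        print(key, value.shape)
    break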
# Warmup proportion for the learning rate schedule
warmup = 0.02
# Learning rate
learning_rate = 5e-5
# Number of training epochs
num_epochs = 10
# Total number of training steps
num_training_steps = 6000
# AdamW epsilon
adam_epsilon = 1e-6
# AdamW weight decay
weight_decay = 0.01
# Log every log_steps steps during training
log_steps = 100
# Evaluate the model every eval_steps steps during training
eval_steps = 1000
# Minimum length of the summary
min_target_length = 6
# Directory to save the trained model
output_dir = 'checkpoints'
# Beam size for decoding
num_beams = 4
log_writer = LogWriter('visualdl_log_dir')
lr_scheduler = LinearDecayWithWarmup(learning_rate, num_training_steps, warmup)
# Exclude bias and LayerNorm parameters from weight decay
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
# AdamW optimizer
optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                   beta1=0.9,
                                   beta2=0.999,
                                   epsilon=adam_epsilon,
                                   parameters=model.parameters(),
                                   weight_decay=weight_decay,
                                   apply_decay_param_fun=lambda x: x in decay_params)
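For intuition, LinearDecayWithWarmup produces a linear warmup followed by a linear decay to zero. The sketch below is a simplified re-derivation for illustration, not the library source.

def linear_decay_with_warmup(step, base_lr=5e-5, total_steps=6000, warmup=0.02):
    """Simplified re-derivation: linear warmup, then linear decay to 0."""
    warmup_steps = int(warmup * total_steps)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# e.g. the learning rate mid-warmup (step 60) and mid-decay (step 3000)
print(linear_decay_with_warmup(60), linear_decay_with_warmup(3000))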
# Compute the evaluation metrics Rouge-1, Rouge-2, Rouge-L and BLEU-4
def compute_metrics(preds, targets):
    assert len(preds) == len(targets), (
        'The length of pred_responses should be equal to the length of '
        'target_responses. But received {} and {}.'.format(len(preds), len(targets)))
    rouge = Rouge()
    bleu4 = BLEU(n_size=4)
    scores = []
    for pred, target in zip(preds, targets):
        try:
            score = rouge.get_scores(' '.join(pred), ' '.join(target))
            scores.append([
                score[0]['rouge-1']['f'], score[0]['rouge-2']['f'],
                score[0]['rouge-l']['f']
            ])
        except ValueError:
            scores.append([0, 0, 0])
        bleu4.add_inst(pred, [target])
    rouge1 = np.mean([i[0] for i in scores])
    rouge2 = np.mean([i[1] for i in scores])
    rougel = np.mean([i[2] for i in scores])
    bleu4 = bleu4.score()
    print('\n' + '*' * 15)
    print('The auto evaluation result is:')
    print('rouge-1:', round(rouge1 * 100, 2))
    print('rouge-2:', round(rouge2 * 100, 2))
    print('rouge-L:', round(rougel * 100, 2))
    print('BLEU-4:', round(bleu4 * 100, 2))
    return rouge1, rouge2, rougel, bleu4
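A quick way to exercise this function is to call it on a made-up prediction/reference pair. Because the function joins each string character by character with spaces before handing it to rouge, plain Chinese strings work directly; the strings below are hypothetical and the printed numbers are only illustrative.

# Hypothetical prediction/reference pair, just to exercise the metric code.
preds = ['今天天气晴朗']
targets = ['今天天气很好']
compute_metrics(preds, targets)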
# Model evaluation function
@paddle.no_grad()
def evaluate(model, data_loader, tokenizer, min_target_length, max_target_length):
    model.eval()
    all_preds = []
    all_labels = []
    model = model._layers if isinstance(model, paddle.DataParallel) else model
    for batch in tqdm(data_loader, total=len(data_loader), desc="Eval step"):
        labels = batch.pop('labels').numpy()
        # Generate with the model
        preds = model.generate(input_ids=batch['input_ids'],
                               attention_mask=batch['attention_mask'],
                               min_length=min_target_length,
                               max_length=max_target_length,
                               decode_strategy='beam_search',
                               num_beams=num_beams,
                               use_cache=True)[0]
        # Use the tokenizer to decode the generated ids back to strings
        all_preds.extend(
            tokenizer.batch_decode(preds.numpy(),
                                   skip_special_tokens=True,
                                   clean_up_tokenization_spaces=False))
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        all_labels.extend(
            tokenizer.batch_decode(labels,
                                   skip_special_tokens=True,
                                   clean_up_tokenization_spaces=False))
    rouge1, rouge2, rougel, bleu4 = compute_metrics(all_preds, all_labels)
    model.train()
    return rouge1, rouge2, rougel, bleu4
def train(model, train_data_loader):
    global_step = 0
    best_rougel = 0
    tic_train = time.time()
    for epoch in range(num_epochs):
        for step, batch in enumerate(train_data_loader):
            global_step += 1
            # Forward pass and loss computation
            _, _, loss = model(**batch)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % log_steps == 0:
                logger.info(
                    "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s"
                    % (global_step, num_training_steps, epoch, step,
                       paddle.distributed.get_rank(), loss, optimizer.get_lr(),
                       log_steps / (time.time() - tic_train)))
                log_writer.add_scalar("train_loss", loss.numpy(), global_step)
                tic_train = time.time()
            if global_step % eval_steps == 0 or global_step == num_training_steps:
                tic_eval = time.time()
                rouge1, rouge2, rougel, bleu4 = evaluate(model, dev_data_loader, tokenizer,
                                                         min_target_length, max_target_length)
                logger.info("eval done total : %s s" % (time.time() - tic_eval))
                log_writer.add_scalar("eval_rouge1", rouge1, global_step)
                log_writer.add_scalar("eval_rouge2", rouge2, global_step)
                log_writer.add_scalar("eval_rougel", rougel, global_step)
                log_writer.add_scalar("eval_bleu4", bleu4, global_step)
                # Save the checkpoint whenever Rouge-L improves
                if best_rougel < rougel:
                    best_rougel = rougel
                    if paddle.distributed.get_rank() == 0:
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        # Need better way to get inner model of DataParallel
                        model_to_save = model._layers if isinstance(model, paddle.DataParallel) else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
train(model, train_data_loader)

Now test the fine-tuned model. Because the data has been encrypted, the text you see looks like gibberish.

model2 = BartForConditionalGeneration.from_pretrained('checkpoints')
model2.eval()
tokenizer2 = BartTokenizer.from_pretrained('checkpoints')
text = '乘傲必彼被杯摆霸傲傲澳颁波奥奔爆报榜吧罢百暗宝诚病拆曹弹丹颁播伯笔壁暗八奥朝唱逼贝缠阐暗财半扮巴储词波瓣八奥爸把暗澳薄巴宾扮八奥版标暗白邦办堡逼贝。保坝坝败八隔爸把。'
infer(text, model2, tokenizer2)
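To look at more than one example, here is a rough sketch that reuses the infer helper over the validation file from earlier; the field names content/title match the data format shown above.

# Hedged sketch: summarize every record in the validation file with the fine-tuned model.
with open('./valid2.json', 'r', encoding='utf-8') as f:
    for line in f:
        record = json.loads(line)
        summary = infer(record['content'], model2, tokenizer2)
        print('reference:', record['title'])
        print('generated:', summary)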
