当前位置: 首页 > news >正文

昇思MindSpore学习总结十六 —— 基于MindSpore的GPT2文本摘要

1、mindnlp 版本要求

!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp

2、数据集加载与处理

2.1 数据集加载

 本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindspore.dataset import TextFileDataset  # 从mindspore.dataset模块中导入TextFileDataset类# load dataset  # 加载数据集
dataset = TextFileDataset(str(path), shuffle=False)  # 创建一个TextFileDataset实例,参数是文件路径(path)转换成字符串格式,shuffle=False表示不打乱数据顺序
dataset.get_dataset_size()  # 获取数据集的大小,即数据集中样本的数量

# split into training and testing dataset  # 将数据集分割为训练集和测试集
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)  # 将数据集按比例[0.9, 0.1]分割为训练集和测试集,randomize=False表示不随机打乱数据

 2.2 数据预处理

import json  # 导入json模块,用于处理JSON数据
import numpy as np  # 导入numpy模块,并简写为np,用于处理数组和矩阵# preprocess dataset  # 预处理数据集
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数merge_and_pad,用于合并并填充数据def merge_and_pad(article, summary):# tokenization  # 进行分词操作# pad to max_seq_length, only truncate the article  # 填充到最大序列长度,仅截断文章部分tokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)  # 使用tokenizer对文章和摘要进行分词,填充到最大长度,仅截断文章部分return tokenized['input_ids'], tokenized['input_ids']  # 返回分词后的输入ID(注意:这里的input_ids和labels是相同的)dataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要# change column names to input_ids and labels for the following training  # 更改列名为input_ids和labels,以便后续训练dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])  # 使用merge_and_pad函数对数据进行映射,生成input_ids和labelsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理if shuffle:dataset = dataset.shuffle(batch_size)  # 如果shuffle为True,则对批次进行随机打乱return dataset  # 返回预处理后的数据集

 因GPT2无中文的tokenizer,我们使用BertTokenizer替代。

from mindnlp.transformers import BertTokenizer  # 从mindnlp.transformers模块中导入BertTokenizer类# We use BertTokenizer for tokenizing Chinese context.  # 我们使用BertTokenizer对中文内容进行分词
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 使用预训练的'bert-base-chinese'模型初始化BertTokenizer
len(tokenizer)  # 获取tokenizer的词汇表大小

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)  # 使用process_dataset函数对训练数据集进行预处理,传入参数包括训练数据集、分词器和批次大小为4
next(train_dataset.create_tuple_iterator())  # 创建一个tuple迭代器并获取其第一个元素

 3、模型构建

3.1 构建GPT2ForSummarization模型,注意shift right的操作。

from mindspore import ops  # 从mindspore模块导入ops操作
from mindnlp.transformers import GPT2LMHeadModel  # 从mindnlp.transformers模块中导入GPT2LMHeadModel类# 定义一个用于摘要生成的GPT2模型类,继承自GPT2LMHeadModel
class GPT2ForSummarization(GPT2LMHeadModel):# 定义模型的构造函数def construct(self,input_ids=None,  # 输入IDattention_mask=None,  # 注意力掩码labels=None,  # 标签):# 调用父类的construct方法,获取模型输出outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]  # 移动logits,使其与shift_labels对齐shift_labels = labels[..., 1:]  # 移动标签,使其与shift_logits对齐# Flatten the tokens  # 将tokens展平loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)  # 计算交叉熵损失,忽略填充的tokenreturn loss  # 返回计算的损失

 3.2 动态学习率

from mindspore import ops  # 从mindspore模块导入ops操作
from mindspore.nn.learning_rate_schedule import LearningRateSchedule  # 从mindspore.nn.learning_rate_schedule模块导入LearningRateSchedule类# 定义一个线性学习率衰减与热身相结合的学习率调度类,继承自LearningRateSchedule
class LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate.  # 热身-衰减学习率。"""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()  # 调用父类的构造函数self.learning_rate = learning_rate  # 初始化学习率self.num_warmup_steps = num_warmup_steps  # 初始化热身步数self.num_training_steps = num_training_steps  # 初始化训练步数# 定义构造函数def construct(self, global_step):# 如果当前步数小于热身步数if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate  # 线性增加学习率# 否则,学习率进行线性衰减return ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate  # 计算并返回衰减后的学习率

 4、模型训练

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
from mindspore import nn  # 从mindspore模块导入nn(神经网络)模块
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel  # 从mindnlp.transformers模块导入GPT2Config和GPT2LMHeadModel类# 配置GPT2模型的配置
config = GPT2Config(vocab_size=len(tokenizer))  # 创建GPT2配置实例,设置词汇表大小为tokenizer的长度
model = GPT2ForSummarization(config)  # 使用配置实例创建一个GPT2ForSummarization模型# 创建学习率调度器
lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)  # 创建线性热身-衰减学习率调度器# 创建优化器
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)  # 使用AdamWeightDecay优化器,并传入模型的可训练参数和学习率调度器
# 记录模型参数数量
print('number of model parameters: {}'.format(model.num_parameters()))

from mindnlp._legacy.engine import Trainer  # 从mindnlp._legacy.engine模块导入Trainer类
from mindnlp._legacy.engine.callbacks import CheckpointCallback  # 从mindnlp._legacy.engine.callbacks模块导入CheckpointCallback类# 创建一个CheckpointCallback实例,用于保存检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint',  # 检查点保存路径ckpt_name='gpt2_summarization',  # 检查点文件名epochs=1,  # 每个epoch保存一次检查点keep_checkpoint_max=2  # 最多保留两个检查点
)# 创建一个Trainer实例,用于训练模型
trainer = Trainer(network=model,  # 要训练的模型train_dataset=train_dataset,  # 训练数据集epochs=1,  # 训练的epoch数optimizer=optimizer,  # 优化器callbacks=ckpoint_cb  # 回调函数,包括检查点回调
)trainer.set_amp(level='O1')  # 开启混合精度训练,级别设置为'O1'

下面这段代码,运行时间较长,最好选择较高算力。 

trainer.run(tgt_columns="labels")  # 运行训练器,指定目标列为“labels”

配置不够,训练时间太长。 

5、模型推理

数据处理,将向量数据变为中文数据

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):# 定义一个嵌套函数read_map,用于读取并解析JSON文本数据def read_map(text):data = json.loads(text.tobytes())  # 将文本数据转换为字节后用json.loads解析为Python字典return np.array(data['article']), np.array(data['summarization'])  # 返回文章和摘要的numpy数组# 定义一个嵌套函数pad,用于对文章进行分词和填充def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)  # 对文章进行分词,截断至最大长度减去摘要长度return tokenized['input_ids']  # 返回分词后的输入IDdataset = dataset.map(read_map, 'text', ['article', 'summary'])  # 使用read_map函数对数据集进行映射,提取文章和摘要dataset = dataset.map(pad, 'article', ['input_ids'])  # 使用pad函数对文章进行分词和填充,生成input_idsdataset = dataset.batch(batch_size)  # 将数据集按批次大小进行分批处理return dataset  # 返回预处理后的数据集
test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
# 创建一个tuple迭代器并获取其第一个元素,以NumPy数组的形式输出,并打印出来
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)  # 从预训练的检查点加载模型
model.set_train(False)  # 设置模型为评估模式(非训练模式)
model.config.eos_token_id = model.config.sep_token_id  # 设置模型的eos_token_id为sep_token_id
i = 0  # 初始化计数器为0# 遍历测试数据集的迭代器,获取输入ID和原始摘要
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():# 使用模型生成新的摘要,参数包括最大新生成的token数量、束搜索的束数、不重复的ngram大小output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)# 将生成的ID转换为文本output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)  # 打印生成的摘要文本i += 1  # 计数器加1if i == 1:  # 如果计数器达到1break  # 跳出循环,仅生成并打印一个摘要

http://www.lryc.cn/news/404114.html

相关文章:

  • React Router 6笔记
  • Android init 中的wait_for_property指令
  • 智能合约语言(eDSL)—— 并行化方案——调度算法
  • vue2.0中如何实现数据监听
  • kafka开启kerberos和ACL
  • QT+winodow 代码适配调试总结(三)
  • Linux之旅:常用的指令,热键和权限管理
  • 简单实用的企业舆情安全解决方案
  • 【中项】系统集成项目管理工程师-第2章 信息技术发展-2.1信息技术及其发展-2.1.1计算机软硬件与2.1.2计算机网络
  • SpringBoot集成Sharding-JDBC-5.3.0实现按月动态建表分表
  • ubuntu 上安装中文输入法
  • Postman导出excel文件
  • 你还在手动构建Python项目吗?PyBuilder让一切自动化!
  • WebRTC音视频-前言介绍
  • centos/rocky容器中安装xfce、xrdp记录
  • 实战:Eureka的概念作用以及用法详解
  • jupyter_contrib_nbextensions安装失败问题
  • 设计模式-Git-其他
  • 【C#】计算两条直线的交点坐标
  • 在项目服务器部署git 并实现自动提交
  • 前缀匹配工具之IP-Prefix
  • 等级保护测评案例分享及合规建议
  • GOLLIE : ANNOTATION GUIDELINES IMPROVE ZERO-SHOT INFORMATION-EXTRACTION
  • 2024-07-19 Unity插件 Odin Inspector9 —— Validation Attributes
  • 跨平台WPF音乐商店应用程序
  • 设计模式简述(一)
  • OSI参考模型:解析网络通信的七层框架
  • QT通用配置文件库(QPreferences)
  • 如何搭建一个RADIUS服务器?
  • 双机热备综合实验