
NLP: Building the Transformer Model

Contents of this article:

  • 1. Code implementation of the encoder-decoder model
  • 2. Instantiating the encoder-decoder model
  • 3. Output of running the code

Preface: The previous articles covered the individual components of the Transformer; this article shows how to assemble them into the complete Transformer model.

In short, the standard Transformer stacks 6 encoder layers and 6 decoder layers, together with an input part (word embedding + positional encoding) and an output layer (the generator).
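Stacking N identical encoder or decoder layers is usually done by deep-copying one layer into an nn.ModuleList, which is what Encoder(layer, N=6) and Decoder(layer, N=6) below rely on internally. A minimal sketch of that idea (the helper name clones is illustrative, not necessarily the name used in the earlier articles):

import copy
import torch.nn as nn

def clones(module, N):
    # Return N independent deep copies of `module` as an nn.ModuleList,
    # so that every layer in the stack has its own parameters.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])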

1. Code implementation of the encoder-decoder model

import torch.nn as nn

# Define the EncoderDecoder class
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder, source_embed, target_embed, generator):
        super().__init__()
        # encoder: the encoder object
        self.encoder = encoder
        # decoder: the decoder object
        self.decoder = decoder
        # source_embed: source-language input part: wordEmbedding + PositionEncoding
        self.source_embed = source_embed
        # target_embed: target-language input part: wordEmbedding + PositionEncoding
        self.target_embed = target_embed
        # generator: the output-layer object
        self.generator = generator

    def forward(self, source, target, source_mask1, source_mask2, target_mask):
        # source: source-language input, shape [batch_size, seq_len] --> [2, 4]
        # target: target-language input, shape [batch_size, seq_len] --> [2, 6]
        # source_mask1: padding mask used in the encoder's multi-head self-attention --> [head, source_seq_len, source_seq_len] --> [8, 4, 4]
        # source_mask2: padding mask used in the decoder's multi-head cross-attention --> [head, target_seq_len, source_seq_len] --> [8, 6, 4]
        # target_mask: sentence (subsequent) mask used in the decoder's multi-head self-attention --> [head, target_seq_len, target_seq_len] --> [8, 6, 6]

        # 1. Pass the raw source input [batch_size, seq_len] --> [2, 4] through the encoder input part
        #    (wordEmbedding + PositionEncoding) to get a tensor of shape [2, 4, 512]
        encode_word_embed = self.source_embed(source)
        # 2. Feed encode_word_embed and source_mask1 into the encoder to get the encoded result: encoder_output --> [2, 4, 512]
        encoder_output = self.encoder(encode_word_embed, source_mask1)
        # 3. Pass the target input [batch_size, seq_len] --> [2, 6] through the decoder input part to get [2, 6, 512]
        decode_word_embed = self.target_embed(target)
        # 4. Feed decode_word_embed, encoder_output, source_mask2 and target_mask into the decoder
        decoder_output = self.decoder(decode_word_embed, encoder_output, source_mask2, target_mask)
        # 5. Feed decoder_output into the output layer
        output = self.generator(decoder_output)
        return output
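Note that the demo in the next section passes all-zero tensors as masks just to exercise the forward pass. In real training, source_mask1/source_mask2 would be padding masks built from the pad token, and target_mask would additionally block attention to future positions. A minimal sketch of how such masks could be built (the pad_id value and helper names are illustrative, and the shapes here follow a [batch_size, ..., seq_len] convention rather than the [head, seq_len, seq_len] shapes used in the demo):

import torch

def make_padding_mask(seq, pad_id=0):
    # seq: [batch_size, seq_len] -> bool mask [batch_size, 1, seq_len]
    # True where the position holds a real token, False where it is padding.
    return (seq != pad_id).unsqueeze(1)

def make_subsequent_mask(size):
    # Lower-triangular bool mask [1, size, size]:
    # position i may only attend to positions <= i.
    return torch.tril(torch.ones(1, size, size)).bool()

# For the decoder self-attention, the two constraints are combined, e.g.:
# target_mask = make_padding_mask(target) & make_subsequent_mask(target.size(1))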

2. Instantiating the encoder-decoder model

import copy
import torch
import torch.nn as nn

# MutiHeadAttention, FeedForward, EncoderLayer, Encoder, DecoderLayer, Decoder,
# Embeddings, PositionEncoding and Generator are the component classes implemented
# in the previous articles of this series.

def dm_transformer():
    # 1. Instantiate the encoder object
    # Instantiate the multi-head attention object
    mha = MutiHeadAttention(embed_dim=512, head=8, dropout_p=0.1)
    # Instantiate the feed-forward layer object
    ff = FeedForward(d_model=512, d_ff=1024)
    encoder_layer = EncoderLayer(size=512, self_atten=mha, ff=ff, dropout_p=0.1)
    encoder = Encoder(layer=encoder_layer, N=6)

    # 2. Instantiate the decoder object
    self_attn = copy.deepcopy(mha)
    src_attn = copy.deepcopy(mha)
    feed_forward = copy.deepcopy(ff)
    decoder_layer = DecoderLayer(size=512, self_attn=self_attn, src_attn=src_attn,
                                 feed_forward=feed_forward, dropout_p=0.1)
    decoder = Decoder(layer=decoder_layer, N=6)

    # 3. Source-language input part: wordEmbedding + PositionEncoding
    # Embedding layer
    vocab_size = 1000
    d_model = 512
    encoder_embed = Embeddings(vocab_size=vocab_size, d_model=d_model)
    # Positional-encoding layer (inside the position encoder, embed_x has already been added in)
    dropout_p = 0.1
    encoder_pe = PositionEncoding(d_model=d_model, dropout_p=dropout_p)
    source_embed = nn.Sequential(encoder_embed, encoder_pe)

    # 4. Target-language input part: wordEmbedding + PositionEncoding
    # Embedding layer
    decoder_embed = copy.deepcopy(encoder_embed)
    # Positional-encoding layer (inside the position encoder, embed_x has already been added in)
    decoder_pe = copy.deepcopy(encoder_pe)
    target_embed = nn.Sequential(decoder_embed, decoder_pe)

    # 5. Instantiate the output-layer object
    generator = Generator(d_model=512, vocab_size=2000)

    # 6. Instantiate the EncoderDecoder object
    transformer = EncoderDecoder(encoder, decoder, source_embed, target_embed, generator)
    print(transformer)

    # 7. Prepare the data
    source = torch.tensor([[1, 2, 3, 4],
                           [2, 5, 6, 10]])
    target = torch.tensor([[1, 20, 3, 4, 19, 30],
                           [21, 5, 6, 10, 80, 38]])
    source_mask1 = torch.zeros(8, 4, 4)
    source_mask2 = torch.zeros(8, 6, 4)
    target_mask = torch.zeros(8, 6, 6)

    result = transformer(source, target, source_mask1, source_mask2, target_mask)
    print(f'Final output of the transformer model --> {result}')
    print(f'Shape of the final output --> {result.shape}')


if __name__ == '__main__':
    dm_transformer()
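For reference, the Generator used above is the output layer: a linear projection from d_model=512 to the target vocabulary (vocab_size=2000 here), typically followed by a (log-)softmax, so result has shape [2, 6, 2000]. A minimal sketch of such a class, assuming it matches the version from the earlier articles:

import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    # Output layer: project d_model -> vocab_size, then log-softmax over the vocabulary.
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model] -> [batch_size, seq_len, vocab_size]
        return F.log_softmax(self.proj(x), dim=-1)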

3. Output of running the code

# The final model structure, built according to the Transformer architecture diagram
EncoderDecoder(
  (encoder): Encoder(
    (layers): ModuleList(
      (0): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): EncoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
      (1): DecoderLayer(
        (self_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (src_attn): MultiHeadedAttention(
          (linears): ModuleList(
            (0): Linear(in_features=512, out_features=512)
            (1): Linear(in_features=512, out_features=512)
            (2): Linear(in_features=512, out_features=512)
            (3): Linear(in_features=512, out_features=512)
          )
          (dropout): Dropout(p=0.1)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048)
          (w_2): Linear(in_features=2048, out_features=512)
          (dropout): Dropout(p=0.1)
        )
        (sublayer): ModuleList(
          (0): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (1): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
          (2): SublayerConnection(
            (norm): LayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
      )
    )
    (norm): LayerNorm()
  )
  (src_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (tgt_embed): Sequential(
    (0): Embeddings(
      (lut): Embedding(11, 512)
    )
    (1): PositionalEncoding(
      (dropout): Dropout(p=0.1)
    )
  )
  (generator): Generator(
    (proj): Linear(in_features=512, out_features=11)
  )
)

If any of the code is unclear, please refer to the earlier articles in this series. Thanks for reading; that concludes today's post.
