当前位置: 首页 > news >正文

Transformer详解(5)-编码器和解码器

1、Transformer编码器

import torch
from torch import nn
import copy
from norm import Norm
from multi_head_attention import MultiHeadAttention
from feed_forward import FeedForward
from pos_encoder import PositionalEncoderdef get_clones(module, N):"""Create N identical layers.Args:module: The module (layer) to be duplicated.N: The number of copies to create.Returns:A ModuleList containing N identical copies of the module."""return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])# transformer块
class EncoderLayer(nn.Module):def __init__(self, d_model=512, d_ff=2048, heads=8, dropout=0.1):super(EncoderLayer, self).__init__()self.norm_1 = Norm(d_model)self.norm_2 = Norm(d_model)self.attn = MultiHeadAttention(heads, d_model, dropout)self.ff = FeedForward(d_model, d_ff)self.dropout_1 = nn.Dropout(dropout)self.dropout_2 = nn.Dropout(dropout)def forward(self, x, mask):attn_ouput = self.attn(x, x, x, mask)attn_ouput = self.dropout_1(attn_ouput)x = x + attn_ouput  # 残差连接x = self.norm_1(x)  # 层归一化ff_output = self.ff(x)  # 前馈层ff_output = self.dropout_2(ff_output)x = x + ff_output  # 残差连接x = self.norm_2(x)  # 层归一化return xclass TransformerEncoder(nn.Module):def __init__(self, vocab_size=1000, max_seq_len=50, d_model=512, d_ff=2048, N=6, heads=8, dropout=0.1):super(TransformerEncoder, self).__init__()'''vocab_size  词典大小max_seq_len  序列最大长度d_model  词嵌入大小d_ff  前馈层隐层维度N  编码器中transformer的个数heads  多头个数dropout  dropout比例'''self.N = Nself.embed = nn.Embedding(vocab_size, d_model)self.pe = PositionalEncoder(max_seq_len, d_model)self.layers = get_clones(EncoderLayer(d_model, d_ff, heads, dropout), N)self.norm = Norm(d_model)def forward(self, src, mask=None):x = self.embed(src)  # embeddingx = self.pe(x)  # 位置编码for i in range(self.N):x = self.layers[i](x, mask)output = self.norm(x)return outputif __name__ == '__main__':# Parameterslength = 50low = 0high = 1001  # The upper bound is exclusive in torch.randint# Generate random integersrandom_tensor = torch.randint(low=low, high=high, size=(length,))vocab_size = 1000max_seq_len = 50d_model = 512d_ff = 2048heads = 8N = 2dropout = 0.1trans_encoder = TransformerEncoder(vocab_size, max_seq_len, d_model, d_ff, N, heads, dropout)output = trans_encoder(random_tensor)print(output.shape)  # torch.Size([1, 50, 512])
http://www.lryc.cn/news/355079.html

相关文章:

  • 线程安全-3 JMM
  • 4 CSS的 变换、过渡与动画
  • 前端基础入门三大核心之JS篇:掌握数字魔法 ——「累加器与累乘器」的奥秘籍【含样例代码】
  • git clone 出现的问题
  • Vue2和Vue3生命周期的对比
  • 全面解析Java.lang.ClassCastException异常
  • 美团Java社招面试题真题,最新面试题
  • 二十八、openlayers官网示例Data Tiles解析——自定义绘制DataTile源数据
  • 分布式事务解决方案(最终一致性【TCC解决方案】)
  • App Inventor 2 Encrypt.Security 安全性扩展:MD5哈希,SHA/AES/RSA/BASE64
  • 深入了解Linux中的环境变量
  • 雷军-2022.8小米创业思考-8-和用户交朋友,非粉丝经济;性价比是最大的诚意;新媒体,直播离用户更近;用真诚打动朋友,脸皮厚点!
  • 【Vue2.x】props技术详解
  • C语言例题46、根据公式π/4=1-1/3+1/5-1/7+1/9-1/11+…,计算π的近似值,当最后一项的绝对值小于0.000001为止
  • fpga系列 HDL: 05 阻塞赋值(=)与非阻塞赋值(<=)
  • 大白话DC3算法
  • 力扣HOT100 - 75. 颜色分类
  • Vue.js - 计算属性与侦听器 【0基础向 Vue 基础学习】
  • 技术速递|使用 C# 集合表达式重构代码
  • 我的世界开服保姆级教程
  • [转载]同一台电脑同时使用GitHub和GitLab
  • 【网络协议】【OSI】一次HTTP请求OSI工作过程详细解析
  • springboot vue 开源 会员收银系统 (2) 搭建基础框架
  • Java进阶学习笔记26——包装类
  • 【JavaEE进阶】——要想代码不写死,必须得有spring配置(properties和yml配置文件)
  • 第十四 Elasticsearch介绍和安装
  • YOLOv10介绍与推理--图片和视频演示(附源码)
  • Java实验08
  • MyBatis复习笔记
  • HTML的基石:区块标签与小语义标签的深度解析