当前位置: 首页 > news >正文

AIGC笔记--条件自回归Transformer的搭建

1--概述

        1. 自回归 TransFormer 规定Token只能看到自身及前面的Token,因此需生成一个符合规定的Attention Mask;(代码提供了两种方式自回归Attention Mask的定义方式);

        2. 使用Cross Attention实现条件模态和输入模态之间的模态融合,输入模态作为Query,条件模态作为Key和Value;

2--代码

import torch
import torch.nn as nnclass CrossAttention(nn.Module):def __init__(self, embed_dim: int, num_heads: int):super().__init__()self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)def forward(self, input_x: torch.Tensor, condition: torch.Tensor, attn_mask: torch.Tensor = None):'''query: input_xkey: conditionval: condition'''input_x = self.cross_attn(input_x, condition, condition, attn_mask=attn_mask)[0]return input_xclass Cond_Autoregressive_layer(nn.Module):def __init__(self, input_dim: int, condtion_dim: int, embed_dim: int, num_heads: int):super(Cond_Autoregressive_layer, self).__init__()self.linear1 = nn.Linear(input_dim, embed_dim)self.linear2 = nn.Linear(condtion_dim, embed_dim)self.cond_multihead_attn = CrossAttention(embed_dim = embed_dim, num_heads = num_heads)def forward(self, input_x: torch.Tensor, conditon: torch.Tensor, attention_mask1: torch.Tensor, attention_mask2: torch.Tensor):# q, k, v, attention mask, here we set key and value are both condtion y1 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(conditon), attn_mask = attention_mask1)y2 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(conditon), attn_mask = attention_mask2)return y1, y2if __name__ == "__main__":# set sequence len, embedding dim, multi attention headseq_length = 10input_dim = 32condtion_dim = 128embed_dim = 64num_heads = 8# init input sequence and condtioninput_x = torch.randn(seq_length, 1, input_dim)condtion = torch.randn(seq_length, 1, condtion_dim)# create two attention mask (actually they have the same function)attention_mask1 = torch.triu((torch.ones((seq_length, seq_length)) == 1), diagonal=1) # bool typeattention_mask2 = attention_mask1.float() # True->1 False->0attention_mask2 = attention_mask2.masked_fill(attention_mask2 == 1, float("-inf"))  # Convert ones to -inf# init modelAG_layer = Cond_Autoregressive_layer(input_dim, condtion_dim, embed_dim, num_heads)# forwardy1, y2 = AG_layer(input_x, condtion, attention_mask1, attention_mask2)# here we demonstrate the attention_mask1 and attention_mask2 have the same functionassert(y1[0].equal(y2[0]))

http://www.lryc.cn/news/312195.html

相关文章:

  • 数据结构->链表分类与oj(题),带你提升代码好感
  • unity-unity2d基础操作笔记(三)0.5.000
  • 【精华】AIGC启元2024
  • js对象解构语法
  • flowable使用taskService.addComment新增评论需要full_msg字段进行读取
  • java常用技术栈,java面试带答案
  • 刷题第11天
  • QML中动态增加表格数据
  • OBS插件开发(二)推流实时曲线
  • Linux编程3.3 进程-进程的终止
  • 排序(3)——直接选择排序
  • [LeetBook]【学习日记】数组内重组
  • 【Linux】磁盘情况、挂载,df -h无法看到的卷
  • AIOps实践中常见的挑战:故障根因与可观测性数据的割裂
  • python 远程代码第一次推送
  • C++开发基础之简单的计时器也有适配场景
  • 数电学习笔记——逻辑函数及其描述方法
  • 2024年护眼台灯哪家品牌好?五款优质品牌专业推荐
  • 搜索iconfont或者阿里图标就可以得到免费的图标
  • android实战视频教程,细数Android开发者的艰辛历程
  • nav2_gps_waypoint_follower_demo 不能在ros2 humble中直接使用的解决方法
  • 华为OD机试 - 螺旋数字矩阵
  • Vue响应式内容丢失处理
  • Linux安装Rabbitmq
  • 在nginx 服务器部署vue项目
  • 制作一个简单的HTML个人网页
  • HM2019创建载荷工况
  • Effective C++ 学习笔记 条款14 在资源管理类中小心copying行为
  • c++数据结构算法复习基础-- 3 --线性表-单向链表-笔试面试常见问题
  • 【踩坑专栏】追根溯源,从Linux磁盘爆满排查故障:mycat2与navicat不兼容导致日志暴增