
Difference between torch.matmul and torch.bmm

torch.matmul can be used to multiply 4-D tensors, whereas torch.bmm only accepts 3-D tensors; a quick shape sketch is shown below. The ViTSelfAttention implementation in /home/tiger/.local/lib/python3.9/site-packages/transformers/models/vit/modeling_vit.py is a concrete example: before transpose_for_scores the shape is (batch_size, seq_len, all_head_size), and transpose_for_scores reshapes it to (batch_size, num_attention_heads, seq_len, attention_head_size). torch.matmul treats the leading dimensions of this 4-D tensor as batch dimensions and multiplies only over the last two; the full ViTSelfAttention code is reproduced after the sketch.
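First, a minimal sketch of the shape behavior (the sizes 2, 12, 197, 64 are just illustrative, matching a ViT-Base-style attention layer): torch.matmul batches over the leading dimensions and multiplies the last two, while torch.bmm requires exactly 3-D inputs, so the batch and head dimensions have to be flattened first.

import torch

q = torch.randn(2, 12, 197, 64)   # (batch_size, num_attention_heads, seq_len, attention_head_size)
k = torch.randn(2, 12, 197, 64)

# torch.matmul: the leading (2, 12) dims are treated as batch dims, the last two dims are multiplied.
scores = torch.matmul(q, k.transpose(-1, -2))
print(scores.shape)                # torch.Size([2, 12, 197, 197])

# torch.bmm only accepts 3-D tensors; calling it on q and k directly raises a RuntimeError.
# Flattening (batch, heads) into one batch dimension gives the same result.
scores_bmm = torch.bmm(q.reshape(-1, 197, 64), k.reshape(-1, 197, 64).transpose(1, 2))
print(scores_bmm.shape)            # torch.Size([24, 197, 197])
print(torch.allclose(scores, scores_bmm.view(2, 12, 197, 197)))  # True

The ViTSelfAttention implementation itself: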

import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.models.vit.configuration_vit import ViTConfig


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, seq_len, all_head_size) -> (batch_size, num_attention_heads, seq_len, attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        # Before transpose_for_scores the shape is (batch_size, seq_len, all_head_size); after it,
        # (batch_size, num_attention_heads, seq_len, attention_head_size). torch.matmul multiplies
        # these 4-D tensors over the last two dimensions only.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # import pdb; pdb.set_trace()  # breakpoint used to inspect the 4-D shapes above

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
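A minimal usage sketch (assuming the transformers library is installed; the batch size of 2 and the 197-token sequence length are just example values) that runs the module above and confirms the 4-D attention shapes:

config = ViTConfig()                                     # hidden_size=768, num_attention_heads=12 by default
attn = ViTSelfAttention(config)
hidden_states = torch.randn(2, 197, config.hidden_size)  # (batch_size, seq_len, hidden_size)

context_layer, attention_probs = attn(hidden_states, output_attentions=True)
print(context_layer.shape)    # torch.Size([2, 197, 768])
print(attention_probs.shape)  # torch.Size([2, 12, 197, 197])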

Self-attention can also be implemented with torch.bmm; a sketch is shown below, and see the post "Bert Transformer细节总结" for more details.
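A minimal sketch (not the implementation from that post) of a bmm-based attention: flatten (batch_size, num_heads) into a single batch dimension, apply torch.bmm twice, then restore the 4-D layout. The sizes are illustrative.

import math

import torch

batch_size, num_heads, seq_len, head_size = 2, 12, 197, 64   # example sizes
query_layer = torch.randn(batch_size, num_heads, seq_len, head_size)
key_layer = torch.randn(batch_size, num_heads, seq_len, head_size)
value_layer = torch.randn(batch_size, num_heads, seq_len, head_size)

# Flatten (batch_size, num_heads) into one batch dimension so torch.bmm can be used.
q = query_layer.reshape(-1, seq_len, head_size)
k = key_layer.reshape(-1, seq_len, head_size)
v = value_layer.reshape(-1, seq_len, head_size)

scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_size)   # (batch*heads, seq_len, seq_len)
probs = torch.nn.functional.softmax(scores, dim=-1)
context = torch.bmm(probs, v)                                      # (batch*heads, seq_len, head_size)

# Restore the (batch_size, num_heads, seq_len, head_size) layout.
context = context.view(batch_size, num_heads, seq_len, head_size)

# The matmul-based version used in ViTSelfAttention gives the same result on the 4-D tensors.
scores_4d = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(head_size)
context_4d = torch.matmul(torch.nn.functional.softmax(scores_4d, dim=-1), value_layer)
print(torch.allclose(context, context_4d, atol=1e-6))  # True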

