Difference between torch.matmul and torch.bmm
torch.matmul can multiply 4-D tensors, whereas torch.bmm only works on 3-D tensors. Take the ViTSelfAttention implementation in /home/tiger/.local/lib/python3.9/site-packages/transformers/models/vit/modeling_vit.py as an example: before transpose_for_scores the shape is (batch_size, seq_len, all_head_size), and transpose_for_scores reshapes it into (batch_size, num_attention_heads, seq_len, attention_head_size). torch.matmul multiplies these 4-D tensors over the last two dimensions only:
import math
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import ViTConfig


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        # Before transpose_for_scores the shape is (batch_size, seq_len, all_head_size);
        # transpose_for_scores turns it into (batch_size, num_attention_heads, seq_len, attention_head_size).
        # The 4-D tensors are multiplied over the last two dimensions only.
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
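A minimal sketch of the shape behavior described above (the tensor sizes below are made up for illustration): torch.matmul broadcasts over the leading (batch, head) dimensions of the 4-D tensors and multiplies only the last two, while torch.bmm rejects anything that is not 3-D.

import torch

batch_size, num_heads, seq_len, head_size = 2, 12, 197, 64

# Shapes as produced by transpose_for_scores:
# (batch_size, num_attention_heads, seq_len, attention_head_size)
query = torch.randn(batch_size, num_heads, seq_len, head_size)
key = torch.randn(batch_size, num_heads, seq_len, head_size)

# torch.matmul treats the leading (batch, head) dims as batch dims and multiplies
# the last two: (seq_len, head_size) @ (head_size, seq_len)
scores = torch.matmul(query, key.transpose(-1, -2))
print(scores.shape)  # torch.Size([2, 12, 197, 197])

# torch.bmm accepts only 3-D tensors, so the same call on 4-D inputs fails
try:
    torch.bmm(query, key.transpose(-1, -2))
except RuntimeError as err:
    print(err)  # complains that the inputs must be 3-D tensors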
Self-attention can also be implemented with torch.bmm; see "Bert Transformer细节总结".
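As a rough sketch of that approach (the function name and shapes below are assumptions for illustration, not taken from the linked write-up), the attention core can be expressed with torch.bmm by folding the batch and head dimensions into a single batch dimension:

import math
import torch

def bmm_self_attention(query, key, value):
    # Assumed layout: query/key/value are (batch_size, num_heads, seq_len, head_size)
    b, h, n, d = query.shape
    # Fold (batch, heads) into one batch dimension so torch.bmm can be applied.
    q = query.reshape(b * h, n, d)
    k = key.reshape(b * h, n, d)
    v = value.reshape(b * h, n, d)

    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)  # (b*h, n, n)
    probs = torch.softmax(scores, dim=-1)
    context = torch.bmm(probs, v)                            # (b*h, n, d)
    return context.view(b, h, n, d)

# Sanity check against the torch.matmul path used by ViTSelfAttention above
q = torch.randn(2, 12, 197, 64)
k = torch.randn(2, 12, 197, 64)
v = torch.randn(2, 12, 197, 64)
ref = torch.matmul(
    torch.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(64), dim=-1), v
)
print(torch.allclose(bmm_self_attention(q, k, v), ref, atol=1e-6))  # True

Folding and unfolding with reshape/view is the only extra work compared with the torch.matmul version, which broadcasts over the leading dimensions automatically.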