While the scaled dot-product attention mechanism, discussed previously, allows a model to weigh the importance of different tokens when processing one specific token, it does so using a single set of learned Query (Q), Key (K), and Value (V) projections. This might limit the model's ability to capture diverse types of relationships or focus on different aspects of the input simultaneously. For instance, one attention pattern might be needed to capture syntactic dependencies, while another focuses on semantic similarity across longer distances.
Multi-Head Attention addresses this by running the scaled dot-product attention mechanism multiple times in parallel, each with its own learned linear projections. Each parallel run is called an "attention head". This allows the model to jointly attend to information from different representation subspaces at different positions.
Instead of performing a single attention function with $d_{\text{model}}$-dimensional keys, values, and queries, Multi-Head Attention first linearly projects the queries, keys, and values $h$ times, using a different learned linear projection for each head. Let the input queries, keys, and values be matrices $Q$, $K$, and $V$ (often these are the same tensor in self-attention layers). For each head $i \in \{1, \ldots, h\}$, we compute:
$$\text{head}_i = \text{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V)$$

where the projections are parameter matrices $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$.
The Attention function here is the Scaled Dot-Product Attention described in the previous section. Typically, the dimensions for each head are set to $d_k = d_v = d_{\text{model}}/h$. This division ensures that the total computational cost is similar to that of single-head attention with full dimensionality.
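As a quick numeric check, here is a minimal sketch (assuming the common configuration $d_{\text{model}} = 512$ and $h = 8$, the same values used in the example later in this section) showing that projecting to $h$ heads of size $d_k$ uses exactly as many weights as one full $d_{\text{model}} \times d_{\text{model}}$ projection:

import torch.nn as nn

d_model = 512               # model dimension
num_heads = 8               # number of heads, h
d_k = d_model // num_heads  # per-head dimension: 512 // 8 = 64

# h projections of shape (d_model x d_k) vs. one projection of shape (d_model x d_model)
per_head = nn.Linear(d_model, d_k, bias=False)
params_all_heads = num_heads * sum(p.numel() for p in per_head.parameters())
single_full = nn.Linear(d_model, d_model, bias=False)
params_single = sum(p.numel() for p in single_full.parameters())

print(d_k, params_all_heads, params_single)  # 64 262144 262144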
After computing the attention output for each head in parallel, their outputs (each of dimension $d_v$) are concatenated together:

$$\text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h) \in \mathbb{R}^{\text{sequence\_length} \times (h \cdot d_v)}$$

Since we chose $d_v = d_{\text{model}}/h$, the concatenated dimension is $h \cdot d_v = d_{\text{model}}$. This concatenated output is then passed through a final linear projection, parameterized by $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ (equivalently $\mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$), to produce the final output of the Multi-Head Attention layer:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

The entire process can be visualized as follows:
Input Queries, Keys, and Values are linearly projected independently for each attention head. The outputs of the parallel scaled dot-product attention mechanisms are concatenated and then passed through a final linear projection.
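To make the concatenation and output projection concrete before looking at a full module, the following minimal sketch builds $h$ stand-in head outputs filled with random values (the batch and sequence sizes here are illustrative assumptions) and applies a literal torch.cat followed by $W^O$:

import torch
import torch.nn as nn

batch_size, seq_len = 2, 5
num_heads, d_model = 8, 512
d_v = d_model // num_heads  # 64

# Stand-ins for the h per-head attention outputs, each (batch_size, seq_len, d_v)
heads = [torch.randn(batch_size, seq_len, d_v) for _ in range(num_heads)]

concat = torch.cat(heads, dim=-1)          # (batch_size, seq_len, h * d_v) = (2, 5, 512)
W_o = nn.Linear(num_heads * d_v, d_model)  # the final projection W^O
output = W_o(concat)                       # (batch_size, seq_len, d_model)
print(concat.shape, output.shape)

In the implementation below, this concatenation is not done with an explicit torch.cat; instead, a transpose and view merge the head and feature dimensions, which produces the same layout.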
Let's look at a simplified implementation sketch using PyTorch to highlight the key steps. We assume the input tensors query, key, and value have shape (batch_size, seq_len, d_model).
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, (
            "d_model must be divisible by num_heads"
        )

        self.d_model = d_model
        self.num_heads = num_heads
        # Dimension of keys/queries per head
        self.d_k = d_model // num_heads

        # Linear layers for initial projections
        # (Query, Key, Value for all heads)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Final linear layer after concatenation
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V shape: (batch_size, num_heads, seq_len, d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # attn_scores shape: (batch_size, num_heads, seq_len, seq_len)

        if mask is not None:
            # Apply mask (e.g., for padding or future tokens in decoder)
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)

        attn_probs = F.softmax(attn_scores, dim=-1)
        # attn_probs shape: (batch_size, num_heads, seq_len, seq_len)

        output = torch.matmul(attn_probs, V)
        # output shape: (batch_size, num_heads, seq_len, d_k)
        return output

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # 1. Perform linear projections
        Q = self.W_q(query)  # (batch_size, seq_len, d_model)
        K = self.W_k(key)    # (batch_size, seq_len, d_model)
        V = self.W_v(value)  # (batch_size, seq_len, d_model)

        # 2. Reshape for multi-head attention
        # (batch_size, seq_len, d_model) ->
        # (batch_size, seq_len, num_heads, d_k) ->
        # (batch_size, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply scaled dot-product attention per head
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # attn_output shape: (batch_size, num_heads, seq_len, d_k)

        # 4. Concatenate heads and apply final linear layer
        # Reshape back: (batch_size, num_heads, seq_len, d_k) ->
        # (batch_size, seq_len, num_heads, d_k) ->
        # (batch_size, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        output = self.W_o(attn_output)  # (batch_size, seq_len, d_model)
        return output


# Example usage
# d_model = 512
# num_heads = 8
# multihead_attn = MultiHeadAttention(d_model, num_heads)
# seq_len = 100
# batch_size = 32
# input_tensor = torch.randn(batch_size, seq_len, d_model)  # Example input
# output = multihead_attn(
#     input_tensor, input_tensor, input_tensor
# )  # Self-attention
# print(output.shape)  # Should be torch.Size([32, 100, 512])
This sketch demonstrates how the input Q, K, V tensors are projected and then reshaped to allow parallel computation across heads. The transpose operations are essential for grouping the head dimension alongside the batch dimension, enabling efficient batched matrix multiplications within the scaled_dot_product_attention function. Finally, the per-head outputs are reshaped back together and passed through the final output projection $W^O$ (the W_o layer in the code).
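As a usage note, the mask argument in the sketch above can enforce causal (decoder-style) attention, where each position may only attend to itself and earlier positions. A minimal example, assuming the MultiHeadAttention class defined above and a mask that broadcasts over the batch and head dimensions:

import torch

d_model, num_heads, seq_len = 512, 8, 10
mha = MultiHeadAttention(d_model, num_heads)
x = torch.randn(4, seq_len, d_model)

# Lower-triangular mask: entry (i, j) is 1 if position i may attend to position j.
# Shape (1, 1, seq_len, seq_len) broadcasts over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

output = mha(x, x, x, mask=causal_mask)
print(output.shape)  # torch.Size([4, 10, 512])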
Using multiple attention heads offers several advantages:

- Each head learns its own projections, so the model can jointly attend to information from different representation subspaces, for example one head tracking syntactic dependencies while another captures longer-range semantic similarity.
- Heads operate in parallel on reduced dimensions ($d_k = d_{\text{model}}/h$), keeping the total computational cost close to that of single-head attention over the full dimensionality.
- The final projection $W^O$ mixes the per-head results, combining these different views into a single $d_{\text{model}}$-dimensional representation.
Multi-Head Attention is a fundamental component not just in the original Transformer but in nearly all subsequent large language models. It provides a powerful and computationally manageable way to enhance the basic attention mechanism. The interaction between these heads, residual connections, and normalization layers (discussed next) forms the core of the Transformer block.
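For reference, PyTorch also provides a built-in torch.nn.MultiheadAttention module. The sketch below is only a shape-level sanity check against the dimensions used earlier; since its weights are initialized independently, its outputs will not numerically match the custom implementation above.

import torch
import torch.nn as nn

d_model, num_heads = 512, 8
x = torch.randn(32, 100, d_model)  # (batch_size, seq_len, d_model)

# batch_first=True lets the module accept (batch, seq_len, d_model) inputs directly.
builtin_mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
attn_output, attn_weights = builtin_mha(x, x, x)  # self-attention
print(attn_output.shape)  # torch.Size([32, 100, 512])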