While the scaled dot-product attention mechanism we implemented previously allows the model to focus on different parts of the sequence, the "Attention Is All You Need" paper introduced Multi-Head Attention to further enhance this capability. Instead of performing a single attention function with $d_{model}$-dimensional keys, values, and queries, multi-head attention projects the queries, keys, and values $h$ times with different, learned linear projections to $d_k$, $d_k$, and $d_v$ dimensions, respectively. Attention is then performed in parallel for each of these projected versions. The outputs are concatenated and once again projected, resulting in the final values.
The intuition is that each "head" can learn to attend to different types of information or relationships within the sequence simultaneously. For instance, one head might focus on syntactic dependencies, while another tracks co-reference relationships. By running these attention mechanisms in parallel and combining their outputs, the model gains a richer, multi-faceted understanding of the input.
Mathematically, Multi-Head Attention is defined as:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where each $\text{head}_i$ is calculated as:

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

Here, $Q, K, V$ are the input query, key, and value matrices. The projection matrices $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, and $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$ are parameter matrices for the $i$-th head, and $W^O \in \mathbb{R}^{h d_v \times d_{model}}$ is the output projection matrix. In the original Transformer paper and many common implementations, the dimensions are set such that $d_k = d_v = d_{model} / h$. This keeps the computational cost similar to single-head attention with $d_{model}$-dimensional keys and values.
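For a concrete sense of these dimensions: the base model in the original paper uses $d_{model} = 512$ and $h = 8$ heads, so each head operates on $d_k = d_v = 512 / 8 = 64$ dimensions. The short sketch below, using just that illustrative configuration, confirms that splitting into heads does not change the total size of the projection weights compared to a single full-width projection:

# Illustrative dimension check (base Transformer configuration).
d_model = 512             # model / embedding dimension
h = 8                     # number of attention heads
d_k = d_v = d_model // h  # 64 dimensions per head

# Stacking the h per-head projections gives a d_model x d_model matrix,
# the same size as a single-head projection of width d_model.
all_heads_params = h * (d_model * d_k)
single_head_params = d_model * d_model
print(d_k, all_heads_params == single_head_params)  # 64 True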
Let's implement this in PyTorch. We'll create a MultiHeadAttention module that takes the embedding dimension (embed_dim), the number of heads (num_heads), and an optional dropout probability as input.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# scaled_dot_product_attention as implemented in the previous section,
# repeated here so this snippet is self-contained.
def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Use a large negative value so masked positions get near-zero weight
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    return output, attn_weights
class MultiHeadAttention(nn.Module):
""" Implements the Multi-Head Attention mechanism. """
def __init__(self, embed_dim, num_heads, dropout=0.1):
super().__init__()
assert embed_dim % num_heads == 0, (
"Embedding dimension must be divisible by number of heads")
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# Linear layers for Q, K, V projections.
# We use a single linear layer for efficiency,
# projecting to embed_dim * 3 and then splitting the result.
# Alternatively, separate layers can be used.
self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
# Output projection layer
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.dropout = nn.Dropout(dropout) # Dropout layer (optional)
self._reset_parameters()
def _reset_parameters(self):
# Use Xavier uniform initialization for linear layers
nn.init.xavier_uniform_(self.qkv_proj.weight)
self.qkv_proj.bias.data.fill_(0)
nn.init.xavier_uniform_(self.out_proj.weight)
self.out_proj.bias.data.fill_(0)
def forward(self, query, key, value, mask=None):
"""
Forward pass for Multi-Head Attention.
Args:
query (torch.Tensor): Query tensor,
shape (batch_size, seq_len_q, embed_dim)
            key (torch.Tensor): Key tensor,
                shape (batch_size, seq_len_k, embed_dim)
            value (torch.Tensor): Value tensor,
                shape (batch_size, seq_len_v, embed_dim)
                Note: seq_len_k must equal seq_len_v.
mask (torch.Tensor, optional): Mask tensor to prevent
attention to certain positions.
Shape (batch_size, 1, seq_len_q,
seq_len_k) or similar
broadcastable shape.
Returns:
torch.Tensor: Output tensor,
shape (batch_size, seq_len_q, embed_dim)
torch.Tensor: Attention weights,
shape (batch_size, num_heads, seq_len_q, seq_len_k)
"""
        batch_size, seq_len_q, _ = query.size()
        # Key and value sequence lengths must match
        _, seq_len_k, _ = key.size()
        _, seq_len_v, _ = value.size()
        assert seq_len_k == seq_len_v, (
            "Key and value sequence lengths must match")
        # 1. Project Q, K, V using the combined linear layer
        if query is key and key is value:
            # Self-attention: one matrix multiply, then split the result.
            # qkv shape: (batch_size, seq_len, embed_dim * 3) ->
            # 3 tensors of shape (batch_size, seq_len, embed_dim)
            qkv = self.qkv_proj(query)
            q, k, v = qkv.chunk(3, dim=-1)
        else:
            # Encoder-decoder (cross) attention: query and key/value come
            # from different sources, so apply the matching slice of the
            # combined projection weights to each input separately.
            w_q, w_k, w_v = self.qkv_proj.weight.chunk(3, dim=0)
            b_q, b_k, b_v = self.qkv_proj.bias.chunk(3, dim=0)
            q = F.linear(query, w_q, b_q)
            k = F.linear(key, w_k, b_k)
            v = F.linear(value, w_v, b_v)
        # Alternative: use separate nn.Linear layers, e.g.
        # q = self.q_proj(query), k = self.k_proj(key), v = self.v_proj(value)
# 2. Reshape Q, K, V for multi-head computation
# Reshape from (batch_size, seq_len, embed_dim) to
# (batch_size, num_heads, seq_len, head_dim)
q = q.view(batch_size,
seq_len_q,
self.num_heads,
self.head_dim).transpose(1, 2)
k = k.view(batch_size,
seq_len_k,
self.num_heads,
self.head_dim).transpose(1, 2)
v = v.view(batch_size,
seq_len_v,
self.num_heads,
self.head_dim).transpose(1, 2)
# 3. Apply scaled dot-product attention for each head
# The mask needs to be correctly broadcastable.
# If mask is (batch_size, seq_len_q, seq_len_k), it needs to be
# unsqueezed for the head dimension:
# (batch_size, 1, seq_len_q, seq_len_k)
if mask is not None:
# (batch_size, seq_len_q, seq_len_k)
if mask.dim() == 3:
# Add head dimension: (batch_size, 1, seq_len_q, seq_len_k)
mask = mask.unsqueeze(1)
# (seq_len_q, seq_len_k) - same mask for all batches
elif mask.dim() == 2:
# Add batch and head: (1, 1, seq_len_q, seq_len_k)
mask = mask.unsqueeze(0).unsqueeze(0)
# Ensure mask shape is compatible:
# (batch_size, num_heads, seq_len_q, seq_len_k) or broadcastable
# attn_output shape: (batch_size, num_heads, seq_len_q, head_dim)
# attn_weights shape: (batch_size, num_heads, seq_len_q, seq_len_k)
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v, mask=mask
)
# 4. Concatenate heads and project back to embed_dim
# Transpose and reshape to combine heads:
# (batch_size, seq_len_q, num_heads * head_dim)
# num_heads * head_dim = embed_dim
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, seq_len_q, self.embed_dim
)
# Apply final linear projection
output = self.out_proj(attn_output)
# Apply dropout (optional)
output = self.dropout(output)
return output, attn_weights
In this implementation:

- The constructor stores embed_dim and num_heads, and calculates head_dim. It's important that embed_dim is divisible by num_heads.
- We use a single nn.Linear layer (qkv_proj) to hold the query, key, and value projections. In self-attention, where all three inputs are the same tensor, we project once and split the result into q, k, and v; for encoder-decoder attention, the matching slices of the combined weights are applied to each input separately. An alternative is to define separate nn.Linear layers for q, k, and v.
- The _reset_parameters method handles weight initialization, using Xavier uniform initialization, a common practice for Transformer layers.
- In the forward method:
  - The query, key, and value inputs are first projected through qkv_proj.
  - q, k, and v are reshaped to separate the heads. The dimensions become (batch_size, num_heads, seq_len, head_dim). The transpose(1, 2) operation rearranges the dimensions so that the head dimension comes before the sequence length dimension, which is typically expected by attention implementations or optimized kernels.
  - We then call the scaled_dot_product_attention function (implemented in the previous section and repeated above) on the reshaped q, k, v and the optional mask. The mask handling ensures it broadcasts correctly across the heads.
  - The per-head outputs (attn_output) are concatenated back together. We achieve this by first transposing back (transpose(1, 2)) and then using contiguous().view() to merge the num_heads and head_dim dimensions into the original embed_dim. The contiguous() call is necessary because transpose can return a non-contiguous tensor, which view cannot operate on directly.
  - Finally, the result passes through the output projection (out_proj) and an optional dropout layer.

This MultiHeadAttention module encapsulates the core logic described in the Transformer paper. It takes query, key, and value inputs (which are often the same tensor in self-attention layers) and produces an output tensor of the same shape as the query, along with the attention weights for potential analysis. This module will be a building block for the larger Encoder and Decoder layers we construct next.
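As a quick sanity check, here is a minimal usage sketch. The tensor sizes and the causal mask below are illustrative choices, not requirements of the module:

# Minimal usage sketch: self-attention with a causal mask.
# The sizes here are arbitrary illustrative values.
embed_dim, num_heads = 512, 8
batch_size, seq_len = 2, 10

mha = MultiHeadAttention(embed_dim, num_heads, dropout=0.1)
x = torch.randn(batch_size, seq_len, embed_dim)

# Lower-triangular causal mask of shape (seq_len, seq_len); the module
# broadcasts it over the batch and head dimensions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, attn_weights = mha(x, x, x, mask=causal_mask)
print(output.shape)        # torch.Size([2, 10, 512])
print(attn_weights.shape)  # torch.Size([2, 8, 10, 10])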
Flow diagram illustrating the steps within the Multi-Head Attention module, from the input Query, Key, and Value tensors to the final output and attention weights. h denotes the number of heads; bs stands for batch size.