Now that we have explored the theoretical underpinnings of multi-head self-attention, let's translate these concepts into a practical implementation using PyTorch. Building this layer from scratch solidifies understanding of the data flow and tensor manipulations involved. This hands-on exercise assumes familiarity with basic PyTorch modules and tensor operations.
Our goal is to create a `MultiHeadAttention` module that takes an input sequence, projects it into Queries (Q), Keys (K), and Values (V) for multiple heads, computes scaled dot-product attention for each head in parallel, concatenates the results, and applies a final linear projection.

We'll define a Python class inheriting from `torch.nn.Module`. The constructor (`__init__`) will initialize the necessary linear layers for the initial Q, K, V projections and the final output projection. We also need to store the embedding dimension (`d_model`) and the number of attention heads (`num_heads`). An important constraint is that `d_model` must be divisible by `num_heads` so that the projected per-head dimensions (`d_k`, `d_v`) are integers.
```python
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """Implements the Multi-Head Attention mechanism."""

    def __init__(self, d_model: int, num_heads: int):
        """
        Args:
            d_model (int): The dimensionality of the input and output embeddings.
            num_heads (int): The number of attention heads.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension of keys/queries per head

        # Linear layers for Q, K, V projections (can be combined for efficiency)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Final linear layer after concatenation
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Computes scaled dot-product attention."""
        # MatMul QK^T, scaled by sqrt(d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (e.g., for decoder self-attention)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)  # Large negative value

        # Softmax over the key dimension
        attn_probs = F.softmax(attn_scores, dim=-1)

        # MatMul softmax(QK^T / sqrt(d_k)) * V
        output = torch.matmul(attn_probs, V)
        return output, attn_probs  # Return probs for visualization later

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Performs the multi-head attention forward pass.

        Args:
            query (torch.Tensor): Query tensor, shape (batch_size, seq_len_q, d_model)
            key (torch.Tensor): Key tensor, shape (batch_size, seq_len_k, d_model)
            value (torch.Tensor): Value tensor, shape (batch_size, seq_len_v, d_model)
                (seq_len_k and seq_len_v must be the same)
            mask (torch.Tensor, optional): Mask tensor to prevent attention to certain
                positions. Shape depends on application (e.g., padding mask,
                look-ahead mask). Defaults to None.

        Returns:
            torch.Tensor: Output tensor, shape (batch_size, seq_len_q, d_model)
        """
        batch_size = query.size(0)

        # 1. Linear projections
        #    Input shape:  (batch_size, seq_len, d_model)
        #    Output shape: (batch_size, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Reshape for multi-head
        #    Original shape: (batch_size, seq_len, d_model)
        #    Target shape:   (batch_size, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply scaled dot-product attention (per head)
        #    Shapes are now (batch_size, num_heads, seq_len, d_k); attention operates
        #    on the last two dimensions. A mask, if provided, must be broadcastable,
        #    e.g. a padding mask (batch_size, 1, 1, seq_len_k) or a look-ahead mask
        #    (batch_size, 1, seq_len_q, seq_len_k).
        attn_output, attn_probs = self.scaled_dot_product_attention(Q, K, V, mask)
        # attn_output: (batch_size, num_heads, seq_len_q, d_k)
        # attn_probs:  (batch_size, num_heads, seq_len_q, seq_len_k)

        # 4. Concatenate heads
        #    transpose() brings seq_len_q back before the head dimension and
        #    contiguous() ensures the memory layout is suitable for view().
        #    Target shape: (batch_size, seq_len_q, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 5. Final linear projection
        #    (batch_size, seq_len_q, d_model) -> (batch_size, seq_len_q, d_model)
        output = self.W_o(attn_output)
        return output  # Only the final output tensor is usually needed
```
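As a quick sanity check, we can instantiate the module and run a self-attention pass; the batch size, sequence length, and dimensions below are arbitrary illustrative choices.

```python
# Quick shape check. For self-attention, query, key, and value are the same
# tensor. All sizes here are arbitrary illustrative choices.
batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
x = torch.randn(batch_size, seq_len, d_model)

output = mha(x, x, x)   # Q, K, V all derived from the same input x
print(output.shape)     # torch.Size([2, 10, 512])
```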
Let's walk through the main pieces of this implementation.

Initialization (`__init__`): We set up four `nn.Linear` layers. `W_q`, `W_k`, and `W_v` project the input embedding of size `d_model` into Q, K, and V vectors, also of size `d_model`. The implementation detail here is that we project to `d_model` first and then reshape, rather than projecting directly to `d_k` per head; both approaches are valid. `W_o` is the final output transformation layer. We also calculate `d_k` (the dimension per head) as `d_model / num_heads`.
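As noted in the comment on the projection layers, the three projections can also be fused into a single linear layer for efficiency. The sketch below illustrates that alternative for the self-attention case, where Q, K, and V all come from the same input; the class name `FusedQKVProjection` is just an illustrative choice, not part of the module above.

```python
# Optional alternative: fuse the three projections into one matmul and split
# the result. Only applicable when Q, K, and V come from the same input, as
# in self-attention. "FusedQKVProjection" is an illustrative name.
class FusedQKVProjection(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)

    def forward(self, x: torch.Tensor):
        # x: (batch_size, seq_len, d_model)
        Q, K, V = self.qkv(x).chunk(3, dim=-1)  # each (batch_size, seq_len, d_model)
        return Q, K, V
```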
Attention calculation (`scaled_dot_product_attention`): This method encapsulates the core attention logic:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

It calculates the attention scores, applies the scaling factor $1/\sqrt{d_k}$, optionally applies a mask (setting masked positions to a large negative number before the softmax), computes attention probabilities using softmax, and finally computes the weighted sum of the Values. Note that PyTorch offers a highly optimized version (`torch.nn.functional.scaled_dot_product_attention`) which should be preferred in production code for performance, but implementing it manually, as shown here, aids understanding.
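To convince ourselves that the manual implementation matches the optimized kernel, a small comparison might look like this (it assumes PyTorch 2.0 or newer and uses no mask or dropout):

```python
# Compare the manual attention against PyTorch's fused kernel (PyTorch 2.0+).
# With no mask and no dropout, both should compute the same result.
mha = MultiHeadAttention(d_model=64, num_heads=4)
x = torch.randn(2, 5, 64)                                # (batch, seq_len, d_model)

# Project and reshape to (batch, num_heads, seq_len, d_k), as in forward()
Q = mha.W_q(x).view(2, 5, 4, 16).transpose(1, 2)
K = mha.W_k(x).view(2, 5, 4, 16).transpose(1, 2)
V = mha.W_v(x).view(2, 5, 4, 16).transpose(1, 2)

manual_out, _ = mha.scaled_dot_product_attention(Q, K, V)
fused_out = F.scaled_dot_product_attention(Q, K, V)      # default scale = 1/sqrt(d_k)
print(torch.allclose(manual_out, fused_out, atol=1e-5))  # expect True
```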
The forward pass (`forward`): The `forward` method accepts `query`, `key`, and `value` tensors. For self-attention (the focus of this chapter), these three are identical, coming from the same input sequence. We keep them separate to maintain generality, as this module can also be used for encoder-decoder cross-attention (covered later). An optional `mask` can also be provided (an example of constructing such masks follows the list below). The method then proceeds in five steps:

1. Linear projections: the `query`, `key`, and `value` tensors are passed through their respective linear layers (`W_q`, `W_k`, `W_v`).
2. Reshape for multi-head: each projected tensor, of shape `(batch_size, seq_len, d_model)`, is reshaped to `(batch_size, num_heads, seq_len, d_k)`. This isolates the computations for each head. The `.view()` method reshapes the tensor, and `.transpose(1, 2)` swaps the `num_heads` and `seq_len` dimensions to prepare for the batched matrix multiplication inside the attention function.
3. Scaled dot-product attention: the `scaled_dot_product_attention` method is called with the reshaped Q, K, and V tensors. Batched matrix multiplication handles the computation across the batch and head dimensions simultaneously.
4. Concatenate heads: the attention output, of shape `(batch_size, num_heads, seq_len_q, d_k)`, needs to be combined. We first `.transpose(1, 2)` back to `(batch_size, seq_len_q, num_heads, d_k)`. `.contiguous()` ensures the tensor is stored in a contiguous block of memory, which is needed before calling `.view()` here because the transpose makes the tensor non-contiguous. Finally, `.view(batch_size, -1, self.d_model)` reshapes it back into the desired `(batch_size, seq_len_q, d_model)` format, effectively concatenating the head outputs along the embedding dimension.
5. Final linear projection: the concatenated output is passed through `W_o` to produce the module's output.
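As referenced above, here is a minimal sketch of how the padding and look-ahead masks described in the docstring might be constructed and passed in; the token IDs and `pad_id` below are made-up values for illustration.

```python
# Minimal mask construction sketch. Token IDs and pad_id are made-up values.
pad_id = 0
token_ids = torch.tensor([[5, 7, 9, pad_id, pad_id]])   # (batch=1, seq_len=5)
seq_len = token_ids.size(1)

# Padding mask, shape (batch, 1, 1, seq_len_k): True where attention is allowed.
padding_mask = (token_ids != pad_id).unsqueeze(1).unsqueeze(2)

# Look-ahead (causal) mask, shape (1, 1, seq_len, seq_len): lower triangular.
causal_mask = torch.tril(
    torch.ones(seq_len, seq_len, dtype=torch.bool)
).view(1, 1, seq_len, seq_len)

# Combined mask for masked decoder-style self-attention; it broadcasts to
# (batch, 1, seq_len_q, seq_len_k), and positions equal to 0/False are masked.
combined_mask = padding_mask & causal_mask

mha = MultiHeadAttention(d_model=64, num_heads=4)
x = torch.randn(1, seq_len, 64)                          # embedded tokens
out = mha(x, x, x, mask=combined_mask)
print(out.shape)                                         # torch.Size([1, 5, 64])
```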
We can visualize the data flow through the multi-head attention layer:
Data flow within the Multi-Head Attention module. B=Batch Size, Lq/Lk/Lv=Sequence Lengths for Q/K/V, dm=Model Dimension, h=Number of Heads, dk=Dimension per Head. For self-attention, Lq=Lk=Lv.
This implementation provides a concrete understanding of how multiple attention heads operate in parallel. Each head potentially focuses on different aspects of the input relationships by using separate projections, and their combined knowledge is integrated through the final linear layer. This layer is a fundamental component used repeatedly within the Transformer encoder and decoder stacks.