This section walks through a practical PyTorch implementation of multi-head self-attention. Building the layer from scratch solidifies understanding of the data flow and tensor manipulations involved; the exercise assumes familiarity with basic PyTorch modules and tensor operations.

Our goal is to create a `MultiHeadAttention` module that takes an input sequence, projects it into Queries (Q), Keys (K), and Values (V) for multiple heads, computes scaled dot-product attention for each head in parallel, concatenates the results, and applies a final linear projection.

## Defining the Module Structure

We'll define a Python class inheriting from `torch.nn.Module`. The constructor (`__init__`) initializes the linear layers for the initial Q, K, V projections and the final output projection, and stores the embedding dimension (`d_model`) and the number of attention heads (`num_heads`). An important constraint is that `d_model` must be divisible by `num_heads` so that the per-head dimensions ($d_k$, $d_v$) are integers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """Implements the Multi-Head Attention mechanism."""

    def __init__(self, d_model: int, num_heads: int):
        """
        Args:
            d_model (int): The dimensionality of the input and output embeddings.
            num_heads (int): The number of attention heads.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension of keys/queries per head

        # Linear layers for Q, K, V projections (can be combined for efficiency)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        # Final linear layer after concatenation
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Computes scaled dot-product attention."""
        # MatMul QK^T and scale by sqrt(d_k)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided (for decoder self-attention)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)  # Use large negative value

        # Softmax over the key dimension
        attn_probs = F.softmax(attn_scores, dim=-1)

        # MatMul softmax(QK^T / sqrt(d_k)) * V
        output = torch.matmul(attn_probs, V)
        return output, attn_probs  # Return probs for visualization later

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        Performs the multi-head attention forward pass.

        Args:
            query (torch.Tensor): Query tensor, shape (batch_size, seq_len_q, d_model)
            key (torch.Tensor): Key tensor, shape (batch_size, seq_len_k, d_model)
            value (torch.Tensor): Value tensor, shape (batch_size, seq_len_v, d_model)
                (seq_len_k and seq_len_v must be the same)
            mask (torch.Tensor, optional): Mask tensor to prevent attention to certain
                positions. Shape depends on application (e.g., padding mask,
                lookahead mask). Defaults to None.

        Returns:
            torch.Tensor: Output tensor, shape (batch_size, seq_len_q, d_model)
        """
        batch_size = query.size(0)

        # 1. Linear Projections
        # Project Q, K, V using the respective weight matrices W_q, W_k, W_v
        # Input shape:  (batch_size, seq_len, d_model)
        # Output shape: (batch_size, seq_len, d_model)
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # 2. Reshape for Multi-Head
        # Reshape Q, K, V to separate the heads
        # Original shape: (batch_size, seq_len, d_model)
        # Target shape:   (batch_size, num_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Apply Scaled Dot-Product Attention (per head)
        # The shapes are now (batch_size, num_heads, seq_len, d_k).
        # The attention function operates on the last two dimensions (seq_len, d_k).
        # If a mask is provided, it must be broadcastable, e.g. a padding mask of shape
        # (batch_size, 1, 1, seq_len_k) or a lookahead mask of shape
        # (batch_size, 1, seq_len_q, seq_len_k).
        attn_output, attn_probs = self.scaled_dot_product_attention(Q, K, V, mask)
        # attn_output shape: (batch_size, num_heads, seq_len_q, d_k)
        # attn_probs shape:  (batch_size, num_heads, seq_len_q, seq_len_k)

        # 4. Concatenate Heads
        # Transpose brings seq_len_q back before the per-head dimensions;
        # contiguous() ensures the memory layout is suitable for view()
        # Target shape: (batch_size, seq_len_q, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        # 5. Final Linear Projection
        # Apply the final weight matrix W_o
        # Input shape:  (batch_size, seq_len_q, d_model)
        # Output shape: (batch_size, seq_len_q, d_model)
        output = self.W_o(attn_output)

        # We usually only need the final output tensor during training/inference
        return output
```
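Before examining the code in detail, a quick smoke test confirms that the shapes line up. The values below (`d_model=512`, `num_heads=8`, a batch of 2 sequences of length 10) are illustrative choices for this sketch, not dimensions prescribed by the text:

```python
# Quick shape check with illustrative dimensions (not prescribed by the text).
batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8

mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: query, key, and value are all the same tensor.
out = mha(x, x, x)
print(out.shape)  # torch.Size([2, 10, 512])
```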
## Code Walkthrough

**Initialization (`__init__`):** We set up four `nn.Linear` layers. `W_q`, `W_k`, and `W_v` project the input embedding of size `d_model` into Q, K, and V vectors, also of size `d_model`. The implementation detail here is that we project to `d_model` first and then reshape, rather than projecting directly to `d_k` per head; both approaches are valid. `W_o` is the final output transformation layer. We calculate `d_k` (the dimension per head) as `d_model // num_heads`.

**Scaled Dot-Product Attention:** For clarity, we include a separate `scaled_dot_product_attention` method. It encapsulates the core attention logic, $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$: it calculates the attention scores, applies the scaling factor $1/\sqrt{d_k}$, optionally applies a mask (setting masked positions to a large negative number before the softmax), computes the attention probabilities with softmax, and finally computes the weighted sum of the Values. Note that PyTorch offers a highly optimized version (`torch.nn.functional.scaled_dot_product_attention`) that should be preferred in production code for performance, but implementing it manually, as shown here, aids understanding.
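Since the optimized kernel is the recommended route for production code, here is a minimal sketch of what that substitution could look like, assuming PyTorch 2.0 or later; the helper name `fused_attention` is illustrative and not part of the module above. One trade-off: the fused call returns only the attention output, not the attention probabilities.

```python
# Minimal sketch, assuming PyTorch >= 2.0. The fused kernel takes the per-head
# tensors of shape (batch_size, num_heads, seq_len, d_k) and applies the
# 1/sqrt(d_k) scaling internally. The helper name is illustrative.
def fused_attention(Q, K, V, mask=None):
    # The manual method blocks positions where mask == 0; the fused kernel expects
    # a boolean attn_mask where True means "attend", so convert accordingly.
    attn_mask = mask.bool() if mask is not None else None
    # Returns only the attention output; attention probabilities are not exposed.
    return F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask)
```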
**Forward Pass (`forward`):**

- *Input:* The `forward` method accepts `query`, `key`, and `value` tensors. For self-attention (the focus of this chapter), these three are identical, coming from the same input sequence. We keep them separate to maintain generality, as this module can also be used for encoder-decoder cross-attention (covered later). An optional `mask` can also be provided; a sketch of how typical masks are constructed follows this walkthrough.
- *Projections:* The input `query`, `key`, and `value` tensors are passed through their respective linear layers (`W_q`, `W_k`, `W_v`).
- *Reshaping:* This is a critical step. The output of the projections, shape `(batch_size, seq_len, d_model)`, needs to be reshaped to `(batch_size, num_heads, seq_len, d_k)`, which isolates the computations for each head. The `.view()` method reshapes the tensor, and `.transpose(1, 2)` swaps the `num_heads` and `seq_len` dimensions to prepare for the batched matrix multiplication inside the attention function.
- *Attention Calculation:* The `scaled_dot_product_attention` method is called with the reshaped Q, K, V tensors. Batched matrix multiplication handles the computation across the batch and head dimensions simultaneously.
- *Concatenation/Reshape Back:* The output of the attention heads, shape `(batch_size, num_heads, seq_len_q, d_k)`, needs to be combined. We first `.transpose(1, 2)` back to `(batch_size, seq_len_q, num_heads, d_k)`. `.contiguous()` ensures the tensor is stored in a contiguous block of memory, which is necessary here because the transpose leaves it non-contiguous and `.view()` requires a contiguous layout. Finally, `.view(batch_size, -1, self.d_model)` reshapes it into the desired `(batch_size, seq_len_q, d_model)` format, effectively concatenating the head outputs along the embedding dimension.
- *Final Projection:* The concatenated tensor is passed through the final linear layer `W_o` to produce the module's output.
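The walkthrough refers to padding and lookahead masks without constructing them. The sketch below shows one common way to build them in the broadcastable shapes quoted in the code comments, using the same convention as the module (1 = attend, 0 = blocked); the variable names and the choice of 0 as the padding token id are assumptions made for this example.

```python
import torch

# Sketch: masks in the broadcastable shapes mentioned in the code comments.
# Convention matches the module above: 1 = attend, 0 = blocked (masked_fill on mask == 0).
# Assumes token id 0 is the padding token; all names here are illustrative.
batch_size, seq_len = 2, 10
token_ids = torch.randint(1, 100, (batch_size, seq_len))  # toy batch of token ids
token_ids[:, 7:] = 0                                      # pad out the last few positions

# Padding mask: (batch_size, 1, 1, seq_len_k); broadcasts over heads and query positions.
padding_mask = (token_ids != 0).int().unsqueeze(1).unsqueeze(2)

# Lookahead (causal) mask: (1, 1, seq_len_q, seq_len_k); broadcasts over the batch dimension.
lookahead_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.int)).unsqueeze(0).unsqueeze(0)

# Decoder self-attention typically combines both; broadcasting yields
# a (batch_size, 1, seq_len_q, seq_len_k) mask that can be passed as `mask`.
combined_mask = padding_mask & lookahead_mask
print(combined_mask.shape)  # torch.Size([2, 1, 10, 10])
```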
## Visualizing the Flow

We can visualize the data flow through the multi-head attention layer:

```dot
digraph G {
    rankdir=TB;
    node [shape=box, style="filled", fillcolor="#e9ecef", fontname="Helvetica"];
    edge [fontname="Helvetica", fontsize=10];

    subgraph cluster_input {
        label="Input Tensors"; style=filled; color="#dee2e6"; rank=same;
        Query [label="Query\n(B, Lq, dm)", shape=cylinder, fillcolor="#a5d8ff"];
        Key [label="Key\n(B, Lk, dm)", shape=cylinder, fillcolor="#a5d8ff"];
        Value [label="Value\n(B, Lv, dm)", shape=cylinder, fillcolor="#a5d8ff"];
    }

    subgraph cluster_projections {
        label="1. Linear Projections"; style=filled; color="#dee2e6";
        Wq [label="Linear Wq", fillcolor="#96f2d7"];
        Wk [label="Linear Wk", fillcolor="#96f2d7"];
        Wv [label="Linear Wv", fillcolor="#96f2d7"];
    }

    subgraph cluster_reshape_split {
        label="2. Reshape & Split Heads"; style=filled; color="#dee2e6";
        ReshapeQ [label="Reshape\n(B, h, Lq, dk)"];
        ReshapeK [label="Reshape\n(B, h, Lk, dk)"];
        ReshapeV [label="Reshape\n(B, h, Lv, dk)"];
    }

    subgraph cluster_attention {
        label="3. Scaled Dot-Product Attention (Parallel Heads)"; style=filled; color="#dee2e6";
        SDPA [label="SDPA\nper head", shape=parallelogram, fillcolor="#bac8ff"];
    }

    subgraph cluster_reshape_concat {
        label="4. Concatenate Heads & Reshape"; style=filled; color="#dee2e6";
        ConcatReshape [label="Transpose & Reshape\n(B, Lq, dm)"];
    }

    subgraph cluster_final_proj {
        label="5. Final Linear Projection"; style=filled; color="#dee2e6";
        Wo [label="Linear Wo", fillcolor="#ffd8a8"];
    }

    subgraph cluster_output {
        label="Output Tensor"; style=filled; color="#dee2e6";
        Output [label="Output\n(B, Lq, dm)", shape=cylinder, fillcolor="#ffec99"];
    }

    Query -> Wq;
    Key -> Wk;
    Value -> Wv;
    Wq -> ReshapeQ;
    Wk -> ReshapeK;
    Wv -> ReshapeV;
    ReshapeQ -> SDPA [label=" Q"];
    ReshapeK -> SDPA [label=" K"];
    ReshapeV -> SDPA [label=" V"];
    SDPA -> ConcatReshape [label="(B, h, Lq, dk)"];
    ConcatReshape -> Wo;
    Wo -> Output;

    {rank=same; Wq; Wk; Wv;}
    {rank=same; ReshapeQ; ReshapeK; ReshapeV;}
}
```

*Data flow within the Multi-Head Attention module. B = batch size, Lq/Lk/Lv = sequence lengths for Q/K/V, dm = model dimension, h = number of heads, dk = dimension per head. For self-attention, Lq = Lk = Lv.*

This implementation provides a concrete understanding of how multiple attention heads operate in parallel: each head potentially focuses on different aspects of the input relationships through its own projections, and their combined knowledge is integrated by the final linear layer. This layer is a fundamental component used repeatedly within the Transformer encoder and decoder stacks.