While Scaled Dot-Product Attention allows the model to weigh the importance of different tokens within a sequence, performing this calculation only once might force the attention mechanism to average diverse types of relationships. Imagine trying to understand a sentence like "The tired animal didn't cross the street because it was too wide." A single attention mechanism might struggle to simultaneously capture both the "tired animal" relationship and the "street width" relationship effectively when focusing on the word "it".
Multi-Head Attention addresses this by running the Scaled Dot-Product Attention process multiple times in parallel, each with different learned transformations of the original queries, keys, and values. This allows each "head" to potentially focus on different aspects or representation subspaces of the information.
Here's the step-by-step process:
Linear Projections: Instead of using a single set of Query (Q), Key (K), and Value (V) matrices, Multi-Head Attention first creates h different sets of these matrices, where h is the number of attention heads (a hyperparameter). For each head i (from 1 to h), the original input Q, K, and V matrices (often derived from the same input sequence embeddings in the case of self-attention) are projected using learned weight matrices: $W_i^Q$, $W_i^K$, and $W_i^V$.
Typically, these projected matrices have smaller dimensions than the original embedding dimension $d_{model}$: each head often works with dimensions $d_k = d_v = d_{model}/h$. This keeps the total computational cost similar to that of single-head attention over the full dimension. The weight matrices ($W_i^Q$, $W_i^K$, $W_i^V$) are unique to each head and are learned during training.
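To make the shapes concrete, here is a minimal NumPy sketch of the projection step. The sizes (a sequence of 10 tokens, $d_{model} = 512$, h = 8 heads, so $d_k = d_v = 64$) and the random weight matrices are illustrative assumptions for this sketch, not values prescribed by the architecture.

```python
import numpy as np

# Illustrative sizes (assumptions for this sketch): 10 tokens, d_model = 512, h = 8 heads.
seq_len, d_model, h = 10, 512, 8
d_k = d_v = d_model // h                      # 64 per head

rng = np.random.default_rng(0)
X = rng.normal(size=(seq_len, d_model))       # input embeddings; for self-attention Q = K = V = X

# One learned projection triple per head (random stand-ins here).
W_Q = rng.normal(size=(h, d_model, d_k))
W_K = rng.normal(size=(h, d_model, d_k))
W_V = rng.normal(size=(h, d_model, d_v))

# Project the same input into h lower-dimensional subspaces.
Q = np.einsum('sd,hdk->hsk', X, W_Q)          # (h, seq_len, d_k)
K = np.einsum('sd,hdk->hsk', X, W_K)          # (h, seq_len, d_k)
V = np.einsum('sd,hdk->hsk', X, W_V)          # (h, seq_len, d_v)
print(Q.shape, K.shape, V.shape)              # (8, 10, 64) for each
```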
Parallel Attention Calculations: Each of these projected sets ($Q_i$, $K_i$, $V_i$) is then fed into its own Scaled Dot-Product Attention mechanism simultaneously. This results in h separate output matrices, let's call them $\text{head}_i$:
$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i$$

Each $\text{head}_i$ matrix captures attention information based on the specific projections learned by head i. Because the projections differ ($W_i^Q$, $W_i^K$, $W_i^V$ are different for each i), each head can potentially learn to focus on different types of relationships or features within the input sequence.
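Because every head applies the same formula to its own projected matrices, the h computations can be batched over a leading head dimension. The function below is a plain NumPy rendering of the formula above; the random arrays are stand-ins for the projected $Q_i$, $K_i$, $V_i$ so the snippet runs on its own.

```python
import numpy as np

# Stand-ins for the projected per-head matrices, shaped (h, seq_len, d_k) and (h, seq_len, d_v).
h, seq_len, d_k, d_v = 8, 10, 64, 64
rng = np.random.default_rng(0)
Q = rng.normal(size=(h, seq_len, d_k))
K = rng.normal(size=(h, seq_len, d_k))
V = rng.normal(size=(h, seq_len, d_v))

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, applied to every head in the leading batch dimension."""
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])   # (h, seq_len, seq_len)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)             # numerically stable row softmax
    return weights @ V                                         # (h, seq_len, d_v)

heads = scaled_dot_product_attention(Q, K, V)
print(heads.shape)                                             # (8, 10, 64): one output per head
```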
Concatenation: The outputs from all h attention heads are concatenated together along the feature dimension. If each $\text{head}_i$ has dimension $d_v$, the concatenated matrix will have dimension $h \times d_v$. Since we typically set $d_v = d_{model}/h$, the dimension of the concatenated matrix becomes $d_{model}$, matching the original input embedding dimension.
$$\text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h)$$

Final Linear Projection: This concatenated output is then passed through one final linear projection layer, parameterized by another learned weight matrix $W^O$. This projection mixes the information learned by the different heads and produces the final output of the Multi-Head Attention layer, which typically has dimension $d_{model}$.
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$

This entire Multi-Head Attention block can then be used as a component within the larger Transformer architecture, replacing the single Scaled Dot-Product Attention mechanism.
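Putting the four steps together, here is a compact sketch of the whole block for a single sequence (self-attention, no masking or batching). The function name, sizes, and random weights are illustrative assumptions; in a real model the weight matrices are learned during training.

```python
import numpy as np

def multi_head_attention(X, W_Q, W_K, W_V, W_O):
    """Minimal single-sequence sketch of MultiHead(Q, K, V) with Q = K = V = X."""
    h, d_model, d_k = W_Q.shape
    # 1. Linear projections: one (Q_i, K_i, V_i) triple per head.
    Q = np.einsum('sd,hdk->hsk', X, W_Q)
    K = np.einsum('sd,hdk->hsk', X, W_K)
    V = np.einsum('sd,hdk->hsk', X, W_V)
    # 2. Scaled Dot-Product Attention, run for all heads in parallel.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                        # (h, seq_len, d_v)
    # 3. Concatenate the head outputs along the feature dimension.
    concat = heads.transpose(1, 0, 2).reshape(X.shape[0], -1)  # (seq_len, h * d_v)
    # 4. Final linear projection back to d_model.
    return concat @ W_O                                        # (seq_len, d_model)

# Usage with illustrative sizes: d_model = 512, h = 8, so d_k = d_v = 64.
rng = np.random.default_rng(0)
seq_len, d_model, h = 10, 512, 8
d_k = d_model // h
X = rng.normal(size=(seq_len, d_model))
W_Q, W_K, W_V = (rng.normal(size=(h, d_model, d_k)) for _ in range(3))
W_O = rng.normal(size=(h * d_k, d_model))
print(multi_head_attention(X, W_Q, W_K, W_V, W_O).shape)       # (10, 512)
```

The shape check at the end simply confirms the point made above: after concatenation and the $W^O$ projection, the output has the same dimension $d_{model}$ as the input embeddings.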
The following diagram illustrates the flow of information through a Multi-Head Attention block with h heads.
This diagram shows how input Q, K, and V matrices are first projected independently for each of the h attention heads. Scaled Dot-Product Attention is then applied to each projected set in parallel. The resulting attention outputs are concatenated and passed through a final linear layer to produce the Multi-Head Attention output.
By allowing different heads to learn different projection matrices ($W_i^Q$, $W_i^K$, $W_i^V$, $W^O$), Multi-Head Attention enables the model to jointly attend to information from different representation subspaces at different positions, leading to a richer and more effective representation compared to using a single attention mechanism.