After computing the scaled dot-product attention in parallel for each of the $h$ heads, we are left with $h$ distinct output matrices, often denoted $\text{head}_1, \text{head}_2, \dots, \text{head}_h$. Recall from the previous section on parallel computations that each $\text{head}_i$ is the result of applying the attention mechanism to projected versions of the original queries, keys, and values:

$$\text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)$$

where $X$ represents the input sequence embeddings (or the output of the previous layer), and $W_i^Q$, $W_i^K$, $W_i^V$ are the parameter matrices for the $i$-th head. Each $\text{head}_i$ has an output dimension of $d_v$, where $d_v = d_{model}/h$. This division keeps the total computational cost across all heads comparable to that of a single attention head operating on the full $d_{model}$ dimension.
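To make the shapes concrete, here is a minimal sketch of computing a single head in PyTorch. The specific values ($d_{model}=512$, $h=8$, a sequence length of 10) and the randomly initialized tensors are illustrative assumptions, not part of any trained model.

```python
import math
import torch

d_model, num_heads = 512, 8
d_k = d_v = d_model // num_heads      # 512 / 8 = 64 per head
seq_len = 10
X = torch.randn(seq_len, d_model)     # input embeddings (illustrative)

# Per-head projection matrices W_i^Q, W_i^K, W_i^V (randomly initialized here)
W_q = torch.randn(d_model, d_k)
W_k = torch.randn(d_model, d_k)
W_v = torch.randn(d_model, d_v)

Q, K, V = X @ W_q, X @ W_k, X @ W_v   # each (seq_len, d_k) or (seq_len, d_v)
scores = Q @ K.T / math.sqrt(d_k)     # scaled dot-product scores
weights = torch.softmax(scores, dim=-1)
head_i = weights @ V                  # one head's output: (seq_len, d_v)
print(head_i.shape)                   # torch.Size([10, 64])
```

Note that each head produces a matrix with only $d_v = 64$ features per position, not the full $d_{model} = 512$.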
However, the subsequent layers in the Transformer, such as the position-wise feed-forward network and the residual connections, expect a single input tensor with the model's hidden dimension, $d_{model}$. Therefore, we need a mechanism to consolidate the outputs from these parallel heads back into a single representation.
The first step in integrating the information captured by the different heads is straightforward: concatenation. The output matrices from all $h$ heads are concatenated along their feature dimension.

If each $\text{head}_i$ is a matrix of shape $(\text{sequence length}, d_v)$, the concatenation operation stacks these matrices side by side, resulting in a new matrix of shape $(\text{sequence length}, h \times d_v)$. Since we defined $d_v = d_{model}/h$, the concatenated matrix has dimensions $(\text{sequence length}, d_{model})$.
$$\text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) \in \mathbb{R}^{\text{seq\_len} \times (h \cdot d_v)} = \mathbb{R}^{\text{seq\_len} \times d_{model}}$$
This operation effectively aggregates the insights learned from the different representational subspaces that each head focused on.
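The sketch below shows this concatenation step, continuing the illustrative shapes from the previous snippet (8 heads of 64 features each, sequence length 10); the random head outputs stand in for the results of the per-head attention computation.

```python
import torch

seq_len, num_heads, d_v = 10, 8, 64
# Stand-ins for head_1, ..., head_h, each of shape (seq_len, d_v)
heads = [torch.randn(seq_len, d_v) for _ in range(num_heads)]

# Stack the head outputs side by side along the feature axis
concat = torch.cat(heads, dim=-1)
print(concat.shape)   # torch.Size([10, 512]) == (seq_len, d_model)
```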
Diagram: outputs from the individual attention heads are concatenated and then projected through a final linear layer.
While concatenation aggregates the head outputs, the resulting tensor simply places the specialized representations next to each other. To allow these representations to interact and be combined effectively, and to ensure the output dimension matches the required $d_{model}$ for subsequent layers (especially for the residual connections), a final linear projection is applied.

This projection is implemented using another learned weight matrix, $W^O$, with dimensions $(h \times d_v, d_{model})$, which simplifies to $(d_{model}, d_{model})$. The concatenated output is multiplied by this weight matrix:

$$\text{MultiHeadOutput} = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W^O$$
This final linear transformation plays an important role: it lets the information from the different heads mix and interact rather than merely sit side by side, and it guarantees an output of dimension $d_{model}$ that is compatible with the residual connections and subsequent layers.
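The projection itself is just one more matrix multiplication, sketched below with the same assumed shapes; in practice $W^O$ is a learned parameter rather than the random matrix used here for illustration.

```python
import torch

seq_len, d_model = 10, 512
concat = torch.randn(seq_len, d_model)   # stand-in for Concat(head_1, ..., head_h)

W_o = torch.randn(d_model, d_model)      # learned (h * d_v, d_model) output projection
multi_head_output = concat @ W_o         # (seq_len, d_model), ready for the residual add
print(multi_head_output.shape)           # torch.Size([10, 512])
```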
In essence, the multi-head mechanism first splits the model's representation capacity into multiple subspaces ($h$ heads, each with dimension $d_v$), allows each head to specialize in attending to different aspects of the input within its subspace, and then uses concatenation followed by a linear projection ($W^O$) to merge these specialized views back into a single, richer representation of dimension $d_{model}$.

This structure allows the model to capture a variety of relational patterns (e.g., short-range dependencies, long-range dependencies, syntactic relationships) simultaneously, which would be more difficult for a single attention mechanism. The parameters of the projection matrices ($W_i^Q$, $W_i^K$, $W_i^V$ for each head) and the final output projection matrix ($W^O$) are all learned end-to-end during model training.
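Putting the pieces together, the following is a simplified, assumed end-to-end sketch of multi-head attention as a PyTorch module. The class name, the per-head `nn.Linear` layers, and the default sizes are illustrative choices; production implementations typically fuse the head projections into single matrices and add batching, masking, and dropout.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Illustrative multi-head self-attention: project, attend per head, concat, project."""

    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by the number of heads"
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        # One projection per head for Q, K, V (the W_i^Q, W_i^K, W_i^V matrices)
        self.W_q = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(num_heads)])
        self.W_k = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(num_heads)])
        self.W_v = nn.ModuleList([nn.Linear(d_model, self.d_k, bias=False) for _ in range(num_heads)])
        # Final output projection W^O, mapping (h * d_v) back to d_model
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                       # x: (seq_len, d_model)
        heads = []
        for i in range(self.num_heads):
            Q, K, V = self.W_q[i](x), self.W_k[i](x), self.W_v[i](x)
            scores = Q @ K.T / math.sqrt(self.d_k)            # scaled dot-product
            heads.append(torch.softmax(scores, dim=-1) @ V)   # (seq_len, d_v)
        # Concatenate the heads, then apply the output projection W^O
        return self.W_o(torch.cat(heads, dim=-1))             # (seq_len, d_model)

x = torch.randn(10, 512)
print(MultiHeadAttention()(x).shape)            # torch.Size([10, 512])
```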