As we delve into the decoder stack, recall its primary role: generating the output sequence one element at a time. In tasks like machine translation, this means producing the translated sentence word by word. This sequential generation process imposes a specific requirement on the self-attention mechanism used within the decoder. Unlike the encoder, which can process the entire input sequence simultaneously, the decoder must predict the next token based only on the tokens generated so far. It cannot look ahead into the future tokens of the output sequence it is currently building.
This is where Masked Multi-Head Self-Attention comes into play. It's a modification of the multi-head self-attention mechanism we discussed in Chapter 2, specifically designed to enforce this "no looking ahead" constraint during training and inference.
Consider the training process for a sequence-to-sequence model. We typically use "teacher forcing," where the decoder is fed the correct target sequence (shifted right, often with a start-of-sequence token) as input to predict the next token at each position. For example, to predict the third word of the target translation, the decoder receives the first two correct words.
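As a minimal sketch (the token IDs below are made up for illustration), this is how a shifted-right decoder input pairs with the prediction targets under teacher forcing:

```python
# Minimal sketch of teacher forcing with hypothetical token IDs:
# the decoder input is the target sequence shifted right, with a
# start-of-sequence (SOS) token prepended.
SOS, EOS = 1, 2
target = [37, 58, 91, EOS]           # correct target tokens (IDs are made up)
decoder_input = [SOS] + target[:-1]  # [1, 37, 58, 91]

# At step i the decoder may only see decoder_input[: i + 1] and is
# trained to predict target[i]; e.g. to predict the third token it
# receives the SOS token plus the first two correct tokens.
for i, gold in enumerate(target):
    print(f"step {i}: visible input {decoder_input[: i + 1]} -> predict {gold}")
```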
If we used standard self-attention here, the attention mechanism at position $i$ would be able to incorporate information from all positions in the target sequence, including positions $i+1$, $i+2$, and so on. This would be like cheating; the model could simply copy the next word from the input instead of learning to predict it based on the preceding words and the encoder's output. The model needs to learn the conditional probability distribution $P(\text{output}_t \mid \text{output}_{<t}, \text{encoder\_output})$. Allowing attention to future tokens breaks this conditional dependency.
During inference (when generating a new sequence), future tokens are unknown anyway. Therefore, the attention mechanism must be consistent between training and inference, only attending to previously generated tokens.
Masked multi-head self-attention achieves this constraint by modifying the scaled dot-product attention calculation within each attention head. Before applying the softmax function to the attention scores, a mask is added.
Recall the scaled dot-product attention formula:
$$\text{scores} = \frac{QK^T}{\sqrt{d_k}}$$

In masked self-attention, we modify this:

$$\text{masked\_scores} = \frac{QK^T}{\sqrt{d_k}} + M$$

$$\text{AttentionWeights} = \text{softmax}(\text{masked\_scores})$$

$$\text{Attention}(Q, K, V) = \text{AttentionWeights} \cdot V$$

The mask $M$ is typically a matrix where elements corresponding to positions the model is not allowed to attend to are set to a very large negative number (effectively negative infinity), and elements corresponding to allowed positions are set to zero.
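The following is a minimal NumPy sketch of this masked scaled dot-product attention. The function name masked_attention and the array shapes are illustrative choices, not taken from any particular library:

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax; exp(-inf) evaluates to 0, so masked
    # positions receive zero weight.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def masked_attention(Q, K, V, mask):
    # Q, K, V: (L, d_k); mask: (L, L) with 0 for allowed positions and
    # -inf for positions that must not be attended to.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)   # scaled dot-product scores
    weights = softmax(scores + mask)  # future positions get ~0 weight
    return weights @ V, weights

L, d_k = 4, 8
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(L, d_k)) for _ in range(3))
mask = np.triu(np.full((L, L), -np.inf), k=1)  # -inf above the diagonal, 0 elsewhere
out, weights = masked_attention(Q, K, V, mask)
print(np.round(weights, 3))  # each row sums to 1; entries above the diagonal are 0
```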
For a target sequence of length $L$, the mask $M$ is an $L \times L$ matrix. For the $i$-th token (row $i$ of the $Q$ matrix), the mask ensures it only attends to tokens $j$ where $j \le i$ (columns $0$ to $i$ of the $K$ matrix). The entries $M_{ij}$ where $j > i$ are set to $-\infty$, while entries where $j \le i$ are set to $0$.
When a large negative number is added to the attention scores for future positions, the subsequent softmax function assigns these positions a probability extremely close to zero. This effectively prevents any information flow from future tokens into the representation of the current token.
Let's visualize the mask for a sequence of length 4:
Attend To --> Pos 1 Pos 2 Pos 3 Pos 4
Query From Pos 1: [ 0 -inf -inf -inf ]
Query From Pos 2: [ 0 0 -inf -inf ]
Query From Pos 3: [ 0 0 0 -inf ]
Query From Pos 4: [ 0 0 0 0 ]
Here, 0 represents an allowed attention connection (adding the mask leaves the original score unchanged), and -inf represents a masked connection (adding the mask drives the score to effectively negative infinity).
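A short NumPy sketch (illustrative, not tied to any framework) that builds this mask directly from the condition $j > i$ described above:

```python
import numpy as np

L = 4
i = np.arange(L)[:, None]   # query positions (rows)
j = np.arange(L)[None, :]   # key positions (columns)

# M[i, j] = -inf where j > i (future positions), 0 where j <= i.
M = np.where(j > i, -np.inf, 0.0)
print(M)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]
```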
The following diagram illustrates the allowed attention connections for the token at position 3 ("predicts") in a hypothetical output sequence "The model predicts well", assuming it's being generated step-by-step.
The diagram shows that the query from position 3 ("predicts") can attend to keys/values from positions 1, 2, and 3, but attention to position 4 ("well") is masked out.
This masking procedure is applied independently within each attention head of the multi-head structure. The overall process remains the same as standard multi-head attention:

1. Linearly project the decoder's input representations into queries, keys, and values for each head.
2. Compute the scaled dot-product scores $\frac{QK^T}{\sqrt{d_k}}$ within each head.
3. Add the look-ahead mask $M$ to the scores.
4. Apply the softmax to obtain attention weights and multiply them by the values.
5. Concatenate the head outputs and apply the final output projection.

The only difference compared to the encoder's self-attention is step 3, the application of the mask, as the sketch below highlights.
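Below is a minimal NumPy sketch of the full computation, following the numbered steps above. The function name, parameter shapes, and random weights are assumptions made for illustration, not a reference implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def masked_multi_head_self_attention(X, Wq, Wk, Wv, Wo, num_heads):
    # X: (L, d_model); Wq, Wk, Wv, Wo: (d_model, d_model). Illustrative shapes.
    L, d_model = X.shape
    d_k = d_model // num_heads

    # Step 1: project into queries, keys, values and split into heads.
    Q = (X @ Wq).reshape(L, num_heads, d_k).transpose(1, 0, 2)  # (h, L, d_k)
    K = (X @ Wk).reshape(L, num_heads, d_k).transpose(1, 0, 2)
    V = (X @ Wv).reshape(L, num_heads, d_k).transpose(1, 0, 2)

    # Step 2: scaled dot-product scores within each head.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)            # (h, L, L)

    # Step 3: add the look-ahead mask (the only decoder-specific step).
    mask = np.where(np.arange(L)[None, :] > np.arange(L)[:, None], -np.inf, 0.0)
    scores = scores + mask                                      # broadcasts over heads

    # Step 4: softmax, then weight the values.
    heads = softmax(scores) @ V                                 # (h, L, d_k)

    # Step 5: concatenate heads and apply the output projection.
    concat = heads.transpose(1, 0, 2).reshape(L, d_model)
    return concat @ Wo

# Tiny usage example with random weights, for shape-checking only.
L, d_model, h = 4, 16, 4
rng = np.random.default_rng(0)
X = rng.normal(size=(L, d_model))
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))
out = masked_multi_head_self_attention(X, Wq, Wk, Wv, Wo, num_heads=h)
print(out.shape)  # (4, 16)
```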
Masked Multi-Head Self-Attention is a modification of the standard self-attention mechanism that is indispensable for the decoder component of the Transformer. By preventing positions from attending to subsequent positions in the output sequence, it ensures that the model's predictions are auto-regressive: the prediction at the current step depends only on the previously generated steps and the input sequence. This is achieved by adding a mask matrix (containing zeros and negative infinities) to the attention scores before the softmax calculation, which drives the post-softmax attention weights for future tokens to zero. The mask is applied within each head of the multi-head attention structure in the decoder's self-attention layers. This contrasts with the encoder's self-attention, which allows each position to attend to all positions in the input sequence.