As we transition from recurrent architectures, the central mechanism replacing sequential processing in the Transformer is attention. Instead of relying on hidden states passed step-by-step, attention allows the model to directly weigh the significance of different parts of the input sequence when processing a specific part. The fundamental unit for this is Scaled Dot-Product Attention.
Imagine you're trying to understand the meaning of the word "bank" in the sentence "The river bank was eroding." To disambiguate "bank", you'd naturally pay more attention to "river" than to "eroding" or "was". Self-attention formalizes this intuition. For each element in the sequence (e.g., each word's embedding), we want to compute a representation that incorporates information from other elements, weighted by their relevance.
To achieve this, Scaled Dot-Product Attention operates on three inputs derived from the input sequence embeddings:

- Queries (Q): representations of the positions that are "asking" for information.
- Keys (K): representations of the positions that each query is matched against.
- Values (V): the content that is actually aggregated, weighted by how well each key matches the query.
In practice, Q, K, and V are often generated by projecting the input embeddings (or the outputs of a previous layer) through separate linear layers with learned weights. Let the input sequence embeddings have dimension $d_{model}$, and let the dimension of the key and query vectors be $d_k$. The dimension of the value vectors is $d_v$ (often $d_k = d_v$, but not necessarily).
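As a minimal sketch of this projection step (the dimensions and layer names below are illustrative, not taken from the text), learned nn.Linear layers can map the embeddings to Q, K, and V:

import torch
import torch.nn as nn

d_model, d_k, d_v = 512, 64, 64            # illustrative dimensions
W_q = nn.Linear(d_model, d_k, bias=False)  # learned projection for queries
W_k = nn.Linear(d_model, d_k, bias=False)  # learned projection for keys
W_v = nn.Linear(d_model, d_v, bias=False)  # learned projection for values

x = torch.randn(1, 6, d_model)             # (Batch, SeqLen, d_model) dummy embeddings
Q, K, V = W_q(x), W_k(x), W_v(x)
print(Q.shape, K.shape, V.shape)
# torch.Size([1, 6, 64]) torch.Size([1, 6, 64]) torch.Size([1, 6, 64])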
The computation proceeds in several steps:
The first step is to measure the compatibility or similarity between each query and all keys. This is done using the dot product. For a single query $q$ and all keys $K$, we compute $q \cdot k_i$ for each key $k_i$. For the full matrices $Q$ and $K$, this is efficiently computed as a matrix multiplication:
$$\text{Scores} = QK^T$$

The resulting matrix contains raw scores where $\text{Scores}_{ij}$ represents the similarity between query $i$ and key $j$. A higher dot product suggests greater relevance between the query and the key.
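As a tiny, self-contained illustration (the vectors are made up for this example), a query receives its highest raw score from the key it is most aligned with:

import torch

# One query and three keys of dimension 4 (toy values for illustration)
q = torch.tensor([[1.0, 0.0, 1.0, 0.0]])            # (1, d_k)
K = torch.tensor([[0.9, 0.1, 1.1, 0.0],             # similar to q
                  [-1.0, 0.5, 0.0, 0.2],            # dissimilar to q
                  [0.0, 1.0, 0.0, 1.0]])            # orthogonal to q
scores = q @ K.T                                     # (1, 3) raw scores
print(scores)
# tensor([[ 2., -1.,  0.]])  -> the first key gets the largest score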
The dot products can grow large in magnitude, especially for larger dimensions $d_k$. Large values pushed into the softmax function (next step) can result in extremely small gradients, hindering learning. To counteract this, the scores are scaled down by the square root of the key dimension, $\sqrt{d_k}$:
$$\text{ScaledScores} = \frac{QK^T}{\sqrt{d_k}}$$

This scaling helps stabilize the gradients and makes training more reliable. The choice of $\sqrt{d_k}$ is based on the assumption that the components of $Q$ and $K$ are independent random variables with zero mean and unit variance. Under this assumption, the dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has a mean of 0 and a variance of $d_k$. Dividing by $\sqrt{d_k}$ brings the variance back to 1, keeping the inputs to the softmax in a reasonable range.
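This effect is easy to check empirically. The following sketch (sample size and dimension chosen arbitrarily) draws unit-variance queries and keys and compares the variance of raw and scaled dot products:

import torch

d_k = 64
q = torch.randn(10000, d_k)        # components ~ N(0, 1)
k = torch.randn(10000, d_k)

raw = (q * k).sum(dim=-1)          # dot products q . k
scaled = raw / (d_k ** 0.5)        # divided by sqrt(d_k)

print(raw.var().item())            # approximately d_k (about 64)
print(scaled.var().item())         # approximately 1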
To convert the scaled scores into a probability distribution representing the attention weights, the softmax function is applied row-wise to the scaled scores matrix:
$$\text{AttentionWeights} = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$$

Each row of the attention weights matrix now sums to 1, and each element $\text{Weights}_{ij}$ indicates how much attention query $i$ should pay to value $j$.
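A quick check with arbitrary scores confirms this row-wise behavior:

import torch
import torch.nn.functional as F

scaled_scores = torch.randn(3, 5)            # (SeqLen_Q, SeqLen_KV), arbitrary values
weights = F.softmax(scaled_scores, dim=-1)   # softmax over the key/value dimension
print(weights.sum(dim=-1))                   # each row sums to 1
# tensor([1.0000, 1.0000, 1.0000])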
Finally, the attention weights are used to compute a weighted sum of the value vectors. This means multiplying the attention weights matrix by the value matrix $V$:

$$\text{Output} = \text{AttentionWeights}\, V$$

The resulting output matrix contains the attention-weighted representations. Each row $\text{Output}_i$ is a vector that is a weighted combination of all value vectors in $V$, where the weights are determined by the similarity of query $i$ to all keys. This output vector effectively incorporates context from the entire sequence, weighted by relevance.
The complete formula for Scaled Dot-Product Attention is thus:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Let's visualize the data flow:
Flow diagram of the Scaled Dot-Product Attention mechanism. Input embeddings are projected into Q, K, V matrices, which are then processed through matrix multiplications, scaling, and softmax to produce weighted output representations.
Here's a simplified implementation using PyTorch to illustrate the core calculation:
import torch
import torch.nn.functional as F
import math
def scaled_dot_product_attention(query, key, value, mask=None):
"""
Computes Scaled Dot-Product Attention.
Args:
query: Query tensor (Batch, SeqLen_Q, Dim_K)
key: Key tensor (Batch, SeqLen_KV, Dim_K)
value: Value tensor (Batch, SeqLen_KV, Dim_V)
mask: Optional mask tensor (Batch, 1, SeqLen_Q, SeqLen_KV)
Returns:
Output tensor (Batch, SeqLen_Q, Dim_V),
Attention weights (Batch, SeqLen_Q, SeqLen_KV)
"""
dim_k = query.size(-1)
# MatMul QK^T: (Batch, SeqLen_Q, Dim_K) x (Batch, Dim_K, SeqLen_KV)
# -> (Batch, SeqLen_Q, SeqLen_KV)
scores = torch.matmul(query, key.transpose(-2, -1))
# Scale
scaled_scores = scores / math.sqrt(dim_k)
# Optional Masking (e.g., for padding or preventing future peeking
# in decoders)
if mask is not None:
# Apply a large negative value where mask is True (or 0)
scaled_scores = scaled_scores.masked_fill(mask == 0, -1e9)
# Using -1e9 for numerical stability
# Softmax
# Softmax operates on the last dimension (SeqLen_KV)
attention_weights = F.softmax(scaled_scores, dim=-1)
# MatMul Weights * V: (Batch, SeqLen_Q, SeqLen_KV) x
# (Batch, SeqLen_KV, Dim_V) -> (Batch, SeqLen_Q, Dim_V)
output = torch.matmul(attention_weights, value)
return output, attention_weights
# Example usage (simplified dimensions)
batch_size = 1
seq_len_q = 3 # Length of query sequence
seq_len_kv = 5 # Length of key/value sequence
dim_k = 8
dim_v = 10
# Dummy input tensors
q = torch.randn(batch_size, seq_len_q, dim_k)
k = torch.randn(batch_size, seq_len_kv, dim_k)
v = torch.randn(batch_size, seq_len_kv, dim_v)
# Compute attention
output, weights = scaled_dot_product_attention(q, k, v)
print("Output shape:", output.shape)
print("Attention weights shape:", weights.shape)
# Example output:
# Output shape: torch.Size([1, 3, 10])
# Attention weights shape: torch.Size([1, 3, 5])
In this code:

- The scaling uses math.sqrt(dim_k), matching the $\sqrt{d_k}$ factor in the formula.
- An optional mask argument is included. Masks are important in Transformers, for instance, to prevent the model from attending to padding tokens or, in the decoder, to prevent attending to future tokens (look-ahead mask). We apply the mask before the softmax by setting masked positions to a very large negative number, ensuring they get near-zero probability after softmax. A short usage sketch follows this list.
- F.softmax computes the attention weights.
- The attention weights are multiplied by the value tensor to get the output.
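As a usage sketch (not part of the original example), the look-ahead case can be exercised by building a lower-triangular mask and passing it in. The mask shape here is (1, SeqLen, SeqLen), which broadcasts against the 3-D score tensor used in this simplified implementation:

# Causal (look-ahead) self-attention over a sequence of length 5.
# Convention: 1 = may attend, 0 = blocked (matches the mask == 0 check above).
seq_len = 5
causal_mask = torch.tril(torch.ones(1, seq_len, seq_len))

x = torch.randn(1, seq_len, dim_k)        # self-attention: queries and keys from the same source
v_self = torch.randn(1, seq_len, dim_v)
out, w = scaled_dot_product_attention(x, x, v_self, mask=causal_mask)

print(out.shape)   # torch.Size([1, 5, 10])
print(w[0])        # upper triangle is ~0: no attention to future positions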
This mechanism forms the heart of the Transformer's ability to capture dependencies regardless of their distance in the sequence, a significant advantage over the sequential bottlenecks inherent in RNNs. However, a single attention calculation might focus only on one type of relationship. To capture diverse relationships simultaneously, the Transformer employs Multi-Head Attention, which we will examine next.