We begin our implementation journey with the fundamental building block of the Transformer's attention mechanism: Scaled Dot-Product Attention. As discussed in Chapter 4, this mechanism allows the model to weigh the importance of different parts of the input sequence when processing a specific element. Instead of relying on recurrence, it computes attention scores based on the interaction between queries, keys, and values derived from the input.
The core calculation for Scaled Dot-Product Attention is defined as:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V

Let's break down the components and the implementation steps:
Query (Q), Key (K), Value (V) Matrices: These matrices are typically projections of the input embeddings. For a given input sequence element (represented as a vector), we generate:

- a query vector, representing the current element seeking information;
- a key vector, representing an element providing information, used for calculating compatibility with the query;
- a value vector, representing the actual content of the element providing information.

If we have a batch of sequences, Q, K, and V become matrices where each row corresponds to an element in the sequence. Their dimensions are typically [batch_size, seq_len, d_model], or, after projection for a single attention head, [batch_size, seq_len, d_k] for Q and K and [batch_size, seq_len, d_v] for V. Often, d_k = d_v.

Compute Dot Products (QKᵀ): The first step is to compute the dot product between the Query matrix Q and the transpose of the Key matrix Kᵀ. This operation calculates a score for how much each query should attend to each key; a higher dot product suggests higher relevance or compatibility between the query and key. The resulting matrix, often called scores or energy, has dimensions [batch_size, seq_len_q, seq_len_k], where seq_len_q is the sequence length of the queries and seq_len_k is the sequence length of the keys (they are often the same in self-attention).
Scale (divide by √d_k): The scores are then scaled down by dividing by the square root of the dimension of the key vectors, √d_k. This scaling is important for stabilizing training. Without it, for large values of d_k the dot products can grow very large in magnitude, and large inputs to the softmax function produce extremely small gradients, making learning difficult. Scaling keeps the variance of the inputs to softmax reasonable; a short numerical sketch after this list of steps illustrates the effect.
Apply Mask (Optional): In many scenarios, we need to prevent attention to certain positions. This is achieved through masking before the softmax step.
Positions to be ignored are marked in the mask (as True or 1), and we add a large negative number (like -1e9 or negative infinity) to the scores at these positions, so that after softmax their attention weights are effectively zero.

Apply Softmax: The softmax function is applied row-wise to the scaled (and potentially masked) scores. This converts the scores into probability distributions, where each value represents the attention weight assigned by a query to a key. The weights for each query sum to 1. The resulting matrix, often called attention_weights, has dimensions [batch_size, seq_len_q, seq_len_k].
Multiply by Values (softmax(…)V): Finally, the attention weights matrix is multiplied by the Value matrix V. This step computes a weighted sum of the value vectors, where the weights are determined by the attention probabilities: elements that received higher attention weights contribute more to the output. The output of the Scaled Dot-Product Attention layer has dimensions [batch_size, seq_len_q, d_v].
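Before writing the full function, it helps to see the effect of the scaling step numerically. The short sketch below is an illustrative check rather than part of the attention implementation; the dimension and sample count are arbitrary choices for the demonstration. It compares the spread of raw and scaled dot-product scores for random vectors and shows how large-magnitude inputs saturate the softmax.

import torch
import torch.nn.functional as F
import math

torch.manual_seed(0)

d_k = 512  # intentionally large key dimension for illustration
q = torch.randn(1000, d_k)
k = torch.randn(1000, d_k)

raw_scores = (q * k).sum(dim=-1)            # dot products of matched pairs
scaled_scores = raw_scores / math.sqrt(d_k)

# Raw dot products have variance roughly d_k; scaling brings it back near 1.
print(f"Std of raw scores:    {raw_scores.std().item():.1f}")     # ~ sqrt(512) ≈ 22.6
print(f"Std of scaled scores: {scaled_scores.std().item():.1f}")  # ~ 1.0

# With large-magnitude inputs, softmax collapses onto a single position,
# leaving near-zero gradients everywhere else.
example = torch.tensor([22.6, 0.0, -22.6])
print(F.softmax(example, dim=-1))                    # ≈ [1.0, 0.0, 0.0]
print(F.softmax(example / math.sqrt(d_k), dim=-1))   # a much smoother distribution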
Let's translate these steps into a PyTorch function. We'll assume the inputs query, key, and value are 3D tensors representing batches of sequences, potentially already projected for a specific attention head.
import torch
import torch.nn.functional as F
import math


def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Computes Scaled Dot-Product Attention.

    Args:
        query: Query tensor; shape (batch_size, num_heads, seq_len_q, d_k)
               or (batch_size, seq_len_q, d_k) if single head.
        key: Key tensor; shape (batch_size, num_heads, seq_len_k, d_k)
             or (batch_size, seq_len_k, d_k) if single head.
        value: Value tensor; shape (batch_size, num_heads, seq_len_v, d_v)
               or (batch_size, seq_len_v, d_v) if single head.
               Note: seq_len_k and seq_len_v must be the same.
        mask: Optional mask tensor; shape should be broadcastable to
              (batch_size, num_heads, seq_len_q, seq_len_k).
              Positions with True or 1 will be masked (set to -inf).

    Returns:
        A tuple containing:
        - output: The attention output tensor;
          shape (batch_size, num_heads, seq_len_q, d_v)
          or (batch_size, seq_len_q, d_v) if single head.
        - attention_weights: The attention weights tensor;
          shape (batch_size, num_heads, seq_len_q, seq_len_k)
          or (batch_size, seq_len_q, seq_len_k) if single head.
    """
    # K is transposed to (..., d_k, seq_len_k) for matmul with Q (..., seq_len_q, d_k),
    # giving scores of shape (..., seq_len_q, seq_len_k).
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask if provided (set masked positions to a large negative value).
    if mask is not None:
        # The mask must have compatible dimensions or be broadcastable.
        # Common mask shapes: (batch_size, 1, 1, seq_len_k) for a padding mask,
        # (batch_size, 1, seq_len_q, seq_len_k) for combined masks.
        # Where mask is True (or 1), we replace the score with -inf
        # (a large negative number like -1e9 also works).
        scores = scores.masked_fill(mask == True, float('-inf'))

    # Apply softmax over the last dimension (seq_len_k) to get attention probabilities.
    attention_weights = F.softmax(scores, dim=-1)

    # NaNs after softmax can appear if every score in a row is -inf;
    # this usually indicates a problem with the mask or the input data.
    if torch.isnan(attention_weights).any():
        print("Warning: NaNs detected in attention weights. "
              "Check masking or input data.")
        # Optionally handle NaNs, e.g. with torch.nan_to_num(attention_weights),
        # although that might hide underlying issues.

    # Multiply weights by values; resulting shape: (..., seq_len_q, d_v).
    output = torch.matmul(attention_weights, value)
    return output, attention_weights
# Example usage (assuming a single head for simplicity)
batch_size = 2
seq_len_q = 5   # Query sequence length
seq_len_k = 7   # Key/Value sequence length
d_k = 64        # Dimension of keys/queries
d_v = 128       # Dimension of values

# Dummy tensors
query_tensor = torch.randn(batch_size, seq_len_q, d_k)
key_tensor = torch.randn(batch_size, seq_len_k, d_k)
value_tensor = torch.randn(batch_size, seq_len_k, d_v)  # seq_len_k == seq_len_v

# Example padding mask (masking the last 2 elements of the key/value sequence)
padding_mask = torch.zeros(batch_size, 1, seq_len_k, dtype=torch.bool)
padding_mask[:, :, -2:] = True  # Mask positions 5 and 6

# Compute attention
output_tensor, attention_weights_tensor = scaled_dot_product_attention(
    query_tensor,
    key_tensor,
    value_tensor,
    mask=padding_mask
)

print("Output shape:", output_tensor.shape)                         # Expected: [2, 5, 128]
print("Attention weights shape:", attention_weights_tensor.shape)   # Expected: [2, 5, 7]

# Verify the masking effect (weights for the last two keys should be near zero)
print("Attention weights for first query in batch 0 (masked last two keys):")
print(attention_weights_tensor[0, 0, :])
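As a quick sanity check on the example above (reusing attention_weights_tensor), we can confirm numerically that each row of attention weights forms a probability distribution and that the two masked key positions receive essentially zero weight:

# Each query's weights should sum to 1 across the key dimension.
row_sums = attention_weights_tensor.sum(dim=-1)
print("Rows sum to 1:", torch.allclose(row_sums, torch.ones_like(row_sums)))

# The masked positions (last two keys) should carry no weight.
max_masked_weight = attention_weights_tensor[..., -2:].abs().max()
print("Max weight on masked keys:", max_masked_weight.item())  # Expected: ~0.0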
This function encapsulates the core logic. Notice how the mask needs to be applied before the softmax. The use of masked_fill with a large negative number effectively prevents masked positions from contributing to the weighted sum after the softmax normalization. The function returns both the final weighted output and the attention weights themselves, which can be useful for analysis or visualization (as we will see in Chapter 23).
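The same masking mechanism also handles the look-ahead (causal) masks used in decoder self-attention, where each position may only attend to itself and earlier positions. The sketch below is one possible construction under this function's convention (True means "mask out"); the sequence length is an arbitrary illustrative value, and batch_size and d_k are reused from the example above.

seq_len = 5

# True above the diagonal: those are future positions each query must not see.
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(causal_mask)

# A padding mask (True where keys are padding) can be OR-combined with it;
# here the padding mask is all False, so only the causal structure applies.
key_padding = torch.zeros(batch_size, 1, seq_len, dtype=torch.bool)
combined_mask = causal_mask | key_padding  # broadcasts to (batch_size, seq_len, seq_len)

x = torch.randn(batch_size, seq_len, d_k)  # self-attention: Q, K, V from the same tensor
out, weights = scaled_dot_product_attention(x, x, x, mask=combined_mask)
print(weights[0])  # each row attends only to itself and earlier positions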
This fundamental building block will now be used within the Multi-Head Attention mechanism, which we implement next. Multi-Head Attention runs this scaled dot-product attention multiple times in parallel with different learned projections of the queries, keys, and values.
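As a small preview of how that will work (not the Multi-Head Attention module itself, which follows in the next section), note that scaled_dot_product_attention only operates on the last two tensor dimensions, so it already accepts the 4D (batch_size, num_heads, seq_len, d) shapes described in its docstring. The head count and per-head dimension below are arbitrary illustrative values:

num_heads = 8   # illustrative head count
head_dim = 32   # illustrative per-head dimension

q4 = torch.randn(batch_size, num_heads, seq_len_q, head_dim)
k4 = torch.randn(batch_size, num_heads, seq_len_k, head_dim)
v4 = torch.randn(batch_size, num_heads, seq_len_k, head_dim)

# Give the earlier padding mask a singleton head dimension so it broadcasts
# across all heads: (batch_size, 1, 1, seq_len_k).
head_mask = padding_mask.unsqueeze(1)

out4, weights4 = scaled_dot_product_attention(q4, k4, v4, mask=head_mask)
print(out4.shape)      # Expected: [2, 8, 5, 32]
print(weights4.shape)  # Expected: [2, 8, 5, 7]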