Now that we've covered the theory behind Self-Attention and the Scaled Dot-Product Attention mechanism, let's translate that understanding into code. This practical exercise focuses on implementing the core attention calculation. Remember, this mechanism is fundamental to how Transformers process information, allowing the model to weigh the importance of different elements in the input sequence relative to each other.
We will implement the following formula:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This function takes Query (Q), Key (K), and Value (V) matrices as input, along with the dimension of the key vectors (d_k) for scaling. Optionally, it can also handle a mask to prevent attention to certain positions (like padding tokens or future tokens in a decoder).
We'll use PyTorch for this implementation, but the concepts translate directly to other deep learning frameworks like TensorFlow. Ensure you have PyTorch installed. We'll also need the math module for the square root calculation.
import torch
import torch.nn.functional as F
import math
Let's define a Python function scaled_dot_product_attention that performs the calculation. It will accept tensors for Q, K, V, and an optional mask.
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Calculates the Scaled Dot-Product Attention.

    Args:
        query (torch.Tensor): Query tensor; shape (batch_size, ..., seq_len_q, d_k)
        key (torch.Tensor): Key tensor; shape (batch_size, ..., seq_len_k, d_k)
        value (torch.Tensor): Value tensor; shape (batch_size, ..., seq_len_v, d_v)
            Note: seq_len_k and seq_len_v must be the same.
        mask (torch.Tensor, optional): Mask tensor; shape must be broadcastable
            to (batch_size, ..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
        torch.Tensor: Output tensor; shape (batch_size, ..., seq_len_q, d_v)
        torch.Tensor: Attention weights; shape (batch_size, ..., seq_len_q, seq_len_k)
    """
    # Get the dimension of the key vectors
    d_k = query.size(-1)

    # 1. Calculate dot products: Q * K^T
    #    Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scale the scores
    attention_scores = attention_scores / math.sqrt(d_k)

    # 3. Apply the mask (if provided)
    #    The mask indicates positions to ignore (e.g., padding).
    #    We add a large negative number (-1e9) to these positions before softmax.
    if mask is not None:
        # Positions where mask == 0 are blocked from receiving attention
        attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

    # 4. Apply softmax to get attention weights
    #    Softmax is applied on the last dimension (seq_len_k)
    #    Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_weights = F.softmax(attention_scores, dim=-1)

    # 5. Multiply weights by Value vectors V
    #    Result shape: (batch_size, ..., seq_len_q, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights
Let's break down the steps within the function:

1. Dot products: torch.matmul handles the batching and matrix multiplication. The key.transpose(-2, -1) operation swaps the last two dimensions of the key tensor, effectively transposing the key matrices for the multiplication. This step calculates the raw alignment scores between queries and keys.
2. Scaling: Dividing the scores by the square root of d_k keeps their magnitude from growing with the key dimension, which would otherwise push the softmax into regions with very small gradients.
3. Masking: If a mask is provided, we apply it here. The mask typically has 0s where attention should be prevented (e.g., padding tokens or future positions in a sequence) and 1s elsewhere. We use masked_fill to replace the scores at masked positions (mask == 0) with a very large negative number (-1e9). When softmax is applied next, these positions will receive near-zero probability (see the short mask sketch after this list).
4. Softmax: The F.softmax function is applied along the last dimension (the key sequence length dimension, seq_len_k). This converts the scaled scores into probability distributions, representing the attention weights. Each query position will have weights summing to 1 across all key positions.
5. Weighted sum: Finally, the attention weights are multiplied with the Value tensor, so each query position receives a weighted combination of the value vectors.

The function returns both the final output tensor and the attention weights, which can be useful for analysis and visualization.
Let's create some sample tensors and see the function in action. We'll assume a batch size of 1, a sequence length of 4, and embedding dimensions (d_k, d_v) of 8. In self-attention, Q, K, and V often derive from the same input sequence, so seq_len_q, seq_len_k, and seq_len_v are typically the same.
# Example Parameters
batch_size = 1
seq_len = 4
d_k = 8 # Dimension of Key/Query
d_v = 8 # Dimension of Value
# Create random Query, Key, Value tensors
# In a real model, these would come from input embeddings projected by linear layers
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_v)
# Calculate attention
output, attention_weights = scaled_dot_product_attention(query, key, value)
print("Input Query Shape:", query.shape)
print("Input Key Shape:", key.shape)
print("Input Value Shape:", value.shape)
print("\nOutput Shape:", output.shape)
print("Attention Weights Shape:", attention_weights.shape)
print("\nSample Attention Weights (first batch element):\n", attention_weights[0])
You should see output similar to this (values will differ due to randomness):
Input Query Shape: torch.Size([1, 4, 8])
Input Key Shape: torch.Size([1, 4, 8])
Input Value Shape: torch.Size([1, 4, 8])
Output Shape: torch.Size([1, 4, 8])
Attention Weights Shape: torch.Size([1, 4, 4])
Sample Attention Weights (first batch element):
tensor([[0.1813, 0.3056, 0.3317, 0.1814],
        [0.2477, 0.2080, 0.3401, 0.2042],
        [0.2880, 0.1807, 0.2523, 0.2790],
        [0.3139, 0.1774, 0.2614, 0.2473]])
Notice that the output shape (1, 4, 8) matches the query and value sequence length and the value dimension (d_v). The attention weights shape (1, 4, 4) represents the attention scores from each of the 4 query positions to each of the 4 key positions. Each row in the sample attention weights sums to approximately 1.
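As a quick sanity check, the following sketch verifies that each row of weights sums to 1 and, assuming PyTorch 2.0 or later, compares our result against the built-in torch.nn.functional.scaled_dot_product_attention (this comparison is an extra check, not part of the original walkthrough).

# Each query position's weights should sum to (approximately) 1
print(attention_weights.sum(dim=-1))  # tensor([[1.0000, 1.0000, 1.0000, 1.0000]])

# Assumption: PyTorch 2.0+ ships a fused reference implementation we can compare against.
reference = F.scaled_dot_product_attention(query, key, value)
print(torch.allclose(output, reference, atol=1e-5))  # expected: True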
Visualizing the attention weights can provide insights into what parts of the input sequence the model focuses on when processing a specific element. Let's use a simple heatmap for the attention_weights we just calculated.
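One possible way to produce such a heatmap is sketched below with matplotlib; the plotting code is not part of the attention implementation itself, and matplotlib is an assumed extra dependency.

import matplotlib.pyplot as plt

# Plot the (seq_len_q, seq_len_k) weight matrix for the first batch element
weights = attention_weights[0].detach().numpy()

fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(weights, cmap="Blues")
ax.set_xlabel("Key position")
ax.set_ylabel("Query position")
ax.set_xticks(range(weights.shape[1]))
ax.set_yticks(range(weights.shape[0]))
fig.colorbar(im, ax=ax)
ax.set_title("Attention Weights")
plt.show()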
Attention weights visualized as a heatmap. Each cell (i, j) shows the attention weight from Query position i to Key position j. Darker blue indicates higher attention.
This visualization shows how much each query position (row) attends to each key position (column). In a real application, like translating "hello world" to French, you might see that when generating the French word for "world", the attention mechanism focuses heavily on the input word "world".
In this section, you implemented the Scaled Dot-Product Attention, the core computational block for attention in Transformers. You saw how to compute scores between queries and keys, scale them, optionally apply masking, normalize using softmax to get weights, and finally compute a weighted sum of values. This function is the building block used within the Multi-Head Attention mechanism discussed earlier, allowing Transformers to effectively process sequence information.
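To connect back to that earlier discussion, here is a minimal, illustrative sketch of how the function can be reused inside a multi-head wrapper. It is not the exact module presented earlier; the class name, d_model, and num_heads are assumptions for this sketch.

import torch.nn as nn

class MultiHeadAttentionSketch(nn.Module):
    """Illustrative only: learned projections + multiple heads around
    scaled_dot_product_attention (structure and names assumed)."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Project and split into heads: (batch, heads, seq_len, d_head)
        def split(t):
            return t.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)

        q, k, v = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))

        # Reuse the function defined above; it broadcasts over the head dimension
        out, weights = scaled_dot_product_attention(q, k, v, mask=mask)

        # Recombine heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.w_o(out), weights

The key design point is that the attention computation itself never changes; the wrapper only reshapes the projected tensors so the same function runs independently on each head.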