Now that we've covered the theory behind Self-Attention and the Scaled Dot-Product Attention mechanism, let's translate that understanding into code. This practical exercise focuses on implementing the core attention calculation. Remember, this mechanism is fundamental to how Transformers process information, allowing the model to weigh the importance of different elements in the input sequence relative to each other.
We will implement the following formula:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This function takes Query ($Q$), Key ($K$), and Value ($V$) matrices as input, along with the dimension of the key vectors ($d_k$) for scaling. Optionally, it can also handle a mask to prevent attention to certain positions (like padding tokens or future tokens in a decoder).
We'll use PyTorch for this implementation, but the concepts translate directly to other deep learning frameworks like TensorFlow. Ensure you have PyTorch installed. We'll also need the `math` module for the square root calculation.
```python
import torch
import torch.nn.functional as F
import math
```
Let's define a Python function `scaled_dot_product_attention` that performs the calculation. It will accept tensors for Q, K, V, and an optional `mask`.
```python
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Calculates the Scaled Dot-Product Attention.

    Args:
        query (torch.Tensor): Query tensor; shape (batch_size, ..., seq_len_q, d_k)
        key (torch.Tensor): Key tensor; shape (batch_size, ..., seq_len_k, d_k)
        value (torch.Tensor): Value tensor; shape (batch_size, ..., seq_len_v, d_v)
            Note: seq_len_k and seq_len_v must be the same.
        mask (torch.Tensor, optional): Mask tensor; shape must be broadcastable
            to (batch_size, ..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
        torch.Tensor: Output tensor; shape (batch_size, ..., seq_len_q, d_v)
        torch.Tensor: Attention weights; shape (batch_size, ..., seq_len_q, seq_len_k)
    """
    # Get the dimension of the key/query vectors
    d_k = query.size(-1)

    # 1. Calculate dot products: Q * K^T
    #    Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scale the scores
    attention_scores = attention_scores / math.sqrt(d_k)

    # 3. Apply the mask (if provided)
    #    The mask indicates positions to ignore (e.g., padding).
    #    We add a large negative number (-1e9) to these positions before softmax.
    if mask is not None:
        # The mask broadcasts against attention_scores; 0 marks positions to ignore
        attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

    # 4. Apply softmax to get attention weights
    #    Softmax is applied on the last dimension (seq_len_k)
    #    Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_weights = F.softmax(attention_scores, dim=-1)

    # 5. Multiply weights by Value vectors V
    #    Result shape: (batch_size, ..., seq_len_q, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights
```
Let's break down the steps within the function:
1. **Matrix Multiplication ($QK^T$)**: We compute the dot product between each query vector and all key vectors. `torch.matmul` handles the batching and matrix multiplication. The `key.transpose(-2, -1)` operation swaps the last two dimensions of the key tensor, effectively transposing the key matrices for the multiplication. This step calculates the raw alignment scores between queries and keys.
2. **Scaling**: The scores are divided by the square root of the dimension ($\sqrt{d_k}$). As discussed previously, this scaling prevents the dot products from becoming too large, which could push the softmax function into regions with very small gradients, hindering learning. A small numerical illustration follows this list.
3. **Masking (Optional)**: If a `mask` is provided, we apply it here. The mask typically has `0`s where attention should be prevented (e.g., padding tokens or future positions in a sequence) and `1`s elsewhere. We use `masked_fill` to replace the scores at masked positions (`mask == 0`) with a very large negative number (`-1e9`). When softmax is applied next, these positions will receive near-zero probability. A sketch showing one way to construct such a mask appears just before the example usage below.
4. **Softmax**: The `F.softmax` function is applied along the last dimension (the sequence length dimension, `seq_len_k`). This converts the scaled scores into probability distributions, representing the attention weights. Each query position will have weights summing to 1 across all key positions.
5. **Matrix Multiplication (Weights * V)**: Finally, the attention weights are multiplied by the Value ($V$) tensor. This computes a weighted sum of the value vectors, where the weights are determined by the attention distribution. The result is the output of the attention mechanism, representing the input sequence with contextually relevant information emphasized for each query position.
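To see why the scaling in step 2 matters in practice, here is a small illustrative check; it is not part of the attention function itself, and the larger dimension of 512 is an arbitrary value chosen purely for demonstration:

```python
# Illustrative only: softmax on unscaled vs. scaled scores for a larger d_k.
# With d_k = 512 (an assumed value), raw dot products of random vectors have
# a standard deviation of roughly sqrt(d_k), so softmax tends to saturate.
torch.manual_seed(0)
d_k_large = 512
q = torch.randn(4, d_k_large)
k = torch.randn(4, d_k_large)

raw_scores = q @ k.T
scaled_scores = raw_scores / math.sqrt(d_k_large)

print(F.softmax(raw_scores, dim=-1)[0])     # typically close to one-hot
print(F.softmax(scaled_scores, dim=-1)[0])  # noticeably smoother distribution
```

The exact numbers depend on the random draw, but the unscaled scores usually produce a nearly one-hot distribution, while the scaled scores stay spread out, which keeps the softmax gradients useful during training.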
The function returns both the final output tensor and the attention weights, which can be useful for analysis and visualization.
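The example below runs without a mask, so here is a brief sketch of one common way to build a causal (look-ahead) mask with `torch.tril` and pass it to the function we just defined; the shapes are placeholder values chosen for illustration:

```python
# Sketch: causal mask for decoder-style self-attention.
# Entries equal to 1 may be attended to; entries equal to 0 are blocked.
seq_len = 4
causal_mask = torch.tril(torch.ones(seq_len, seq_len))  # lower-triangular (4, 4)

q = torch.randn(1, seq_len, 8)
k = torch.randn(1, seq_len, 8)
v = torch.randn(1, seq_len, 8)

# The (4, 4) mask broadcasts against the (1, 4, 4) score tensor
out, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)
print(weights[0])  # weights above the diagonal should be (near) zero
```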
### Example Usage
Let's create some sample tensors and see the function in action. We'll assume a batch size of 1, a sequence length of 4, and embedding dimensions ($d_k$, $d_v$) of 8. In self-attention, Q, K, and V often derive from the same input sequence, so `seq_len_q`, `seq_len_k`, and `seq_len_v` are typically the same.
```python
# Example Parameters
batch_size = 1
seq_len = 4
d_k = 8 # Dimension of Key/Query
d_v = 8 # Dimension of Value
# Create random Query, Key, and Value tensors
# In a real model, these would come from input embeddings projected by linear layers
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_v)
# Calculate attention
output, attention_weights = scaled_dot_product_attention(query, key, value)
print("Input Query Shape:", query.shape)
print("Input Shape:", key.shape)
print("Input Value Shape:", value.shape)
print("\nOutput Shape:", output.shape)
print("Attention Weights Shape:", attention_weights.shape)
print("\nSample Attention Weights (first batch element):\n", attention_weights[0])
```

You should see output similar to this (values will differ due to randomness):
```
Input Query Shape: torch.Size([1, 4, 8])
Input Key Shape: torch.Size([1, 4, 8])
Input Value Shape: torch.Size([1, 4, 8])

Output Shape: torch.Size([1, 4, 8])
Attention Weights Shape: torch.Size([1, 4, 4])

Sample Attention Weights (first batch element):
 tensor([[0.1813, 0.3056, 0.3317, 0.1814],
        [0.2477, 0.2080, 0.3401, 0.2042],
        [0.2880, 0.1807, 0.2523, 0.2790],
        [0.3139, 0.1774, 0.2614, 0.2473]])
```
Notice that the output shape `(1, 4, 8)` matches the query sequence length (4) and the value dimension ($d_v = 8$). The attention weights shape `(1, 4, 4)` represents the attention scores from each of the 4 query positions to each of the 4 key positions. Each row in the sample attention weights sums to approximately 1.
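As a quick sanity check (not part of the original snippet), you can confirm this normalization directly:

```python
# Each query position's weights should sum to 1 (up to floating-point error)
print(attention_weights.sum(dim=-1))
# Should print a tensor of ones with shape (batch_size, seq_len)
```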
Visualizing the attention weights can provide insights into what parts of the input sequence the model focuses on when processing a specific element. Let's use a simple heatmap for the `attention_weights` we just calculated.
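The heatmap appears as a figure on this page; a minimal sketch that produces a comparable plot, assuming `matplotlib` is available, could look like this:

```python
import matplotlib.pyplot as plt

# Plot the (seq_len_q, seq_len_k) weights of the first batch element
weights = attention_weights[0].detach().numpy()

fig, ax = plt.subplots()
im = ax.imshow(weights, cmap="Blues")
ax.set_xlabel("Key position")
ax.set_ylabel("Query position")
ax.set_title("Attention weights")
fig.colorbar(im, ax=ax)
plt.show()
```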
Attention weights visualized as a heatmap. Each cell (i, j) shows the attention weight from Query position i to Key position j. Darker blue indicates higher attention.
This visualization shows how much each query position (row) attends to each key position (column). In a real application, like translating "hello world" to French, you might see that when generating the French word for "world", the attention mechanism focuses heavily on the input word "world".
In this section, you implemented the Scaled Dot-Product Attention, the core computational block for attention in Transformers. You saw how to compute scores between queries and keys, scale them, optionally apply masking, normalize using softmax to get weights, and finally compute a weighted sum of values. This function is the building block used within the Multi-Head Attention mechanism discussed earlier, allowing Transformers to effectively process sequence information.