Having established the theoretical underpinnings of the attention mechanism, specifically Scaled Dot-Product Attention, let's translate this theory into practice. This section provides a hands-on implementation using PyTorch, focusing on the core calculation. Understanding this implementation is fundamental before assembling more complex structures like Multi-Head Attention.
Recall the formula: Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
We'll implement this step by step.
First, ensure you have PyTorch installed. We'll need the core torch library and the math module for the square root operation; Plotly and NumPy are imported only for the visualization at the end of this section.
import torch
import torch.nn.functional as F
import math
import plotly.graph_objects as go # For visualization
import numpy as np # For visualization data handling
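Before writing the full function, it helps to see the formula on toy numbers. The sketch below (values chosen purely for illustration, using the modules imported above) walks through the three steps: score, scale, softmax, and the weighted sum of the values.

# Toy example: one query, two keys/values, d_k = 2 (illustrative numbers only)
q_toy = torch.tensor([[1.0, 0.0]])               # shape (1, 2): a single query
k_toy = torch.tensor([[1.0, 0.0],
                      [0.0, 1.0]])               # shape (2, 2): two keys
v_toy = torch.tensor([[10.0], [20.0]])           # shape (2, 1): two values

scores_toy = q_toy @ k_toy.T / math.sqrt(2)      # raw dot products, scaled by sqrt(d_k)
weights_toy = torch.softmax(scores_toy, dim=-1)  # attention weights, row sums to 1
output_toy = weights_toy @ v_toy                 # weighted combination of the values
print(weights_toy, output_toy)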
Let's define a function scaled_dot_product_attention that takes Queries (Q), Keys (K), Values (V), and an optional mask as input. For simplicity in this initial example, we'll assume Q, K, and V have shapes like [batch_size, sequence_length, dimension]. In a full Transformer, these tensors often have an additional 'head' dimension, which we'll address later. The dimension d_k corresponds to the last dimension of the Key tensor.
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Computes Scaled Dot-Product Attention.

    Args:
        query: Query tensor; shape (batch_size, seq_len_q, d_k)
        key:   Key tensor;   shape (batch_size, seq_len_k, d_k)
        value: Value tensor; shape (batch_size, seq_len_k, d_v)
               Note: seq_len_k must match between key and value.
               d_v (value dimension) can differ from d_k (key/query dimension).
        mask:  Optional mask tensor; shape broadcastable to
               (batch_size, seq_len_q, seq_len_k). Positions with True or 1
               are kept; positions with False or 0 are masked out.

    Returns:
        output: The attention output tensor; shape (batch_size, seq_len_q, d_v)
        attention_weights: The attention weights; shape (batch_size, seq_len_q, seq_len_k)
    """
    # Dimension of the keys (and queries)
    d_k = key.size(-1)

    # 1. Calculate QK^T
    # (batch_size, seq_len_q, d_k) @ (batch_size, d_k, seq_len_k) -> (batch_size, seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)

    # 3. Apply mask (if provided)
    if mask is not None:
        # Positions where the mask is 0/False get a large negative value
        # (e.g. -1e9) so that softmax assigns them (effectively) zero weight.
        scores = scores.masked_fill(mask == 0, -1e9)  # PyTorch convention: 0/False means mask

    # 4. Apply softmax to get attention weights
    # Softmax is applied over the last dimension (seq_len_k)
    attention_weights = F.softmax(scores, dim=-1)

    # Defensive step: replace any NaNs with 0. NaNs can appear if a row is
    # filled entirely with -inf (e.g. a query masked against every key);
    # with the -1e9 fill above this does not happen, but the guard is cheap.
    attention_weights = torch.nan_to_num(attention_weights)

    # 5. Multiply weights by V
    # (batch_size, seq_len_q, seq_len_k) @ (batch_size, seq_len_k, d_v) -> (batch_size, seq_len_q, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights
Let's create some sample tensors and test our function. We'll use a small batch size, sequence length, and dimensions for illustration.
# Example Parameters
batch_size = 1
seq_len_q = 3 # Sequence length for queries
seq_len_k = 4 # Sequence length for keys/values
d_k = 8 # Dimension of keys/queries
d_v = 16 # Dimension of values
# Generate random tensors (replace with actual embeddings in practice)
query = torch.randn(batch_size, seq_len_q, d_k)
key = torch.randn(batch_size, seq_len_k, d_k)
value = torch.randn(batch_size, seq_len_k, d_v)
# --- Without Mask ---
output, attention_weights = scaled_dot_product_attention(query, key, value)
print("--- Output without Mask ---")
print("Output shape:", output.shape) # Expected: [1, 3, 16]
print("Attention weights shape:", attention_weights.shape) # Expected: [1, 3, 4]
# Each row in attention_weights should sum to 1
print("Attention weights sum (first query):", attention_weights[0, 0, :].sum())
# --- With Mask ---
# Create a sample mask. Let's mask out the last key position for all queries.
# A full mask would have shape (batch_size, seq_len_q, seq_len_k); here we use
# the simpler, broadcastable shape (batch_size, 1, seq_len_k).
mask = torch.ones(batch_size, 1, seq_len_k, dtype=torch.bool)
mask[:, :, -1] = 0 # Mask the last key position (index 3)
print("\nMask shape:", mask.shape)
print("Mask content:\n", mask)
output_masked, attention_weights_masked = scaled_dot_product_attention(query, key, value, mask=mask)
print("\n--- Output with Mask ---")
print("Output shape:", output_masked.shape) # Expected: [1, 3, 16]
print("Attention weights shape:", attention_weights_masked.shape) # Expected: [1, 3, 4]
print("Masked Attention weights (first query):\n", attention_weights_masked[0, 0, :])
# Note that the weight for the last key position (index 3) should be 0 or very close to it.
print("Attention weights sum (first query, masked):", attention_weights_masked[0, 0, :].sum())
You should observe that the output shapes match our expectations. When the mask is applied, the attention weight for the masked key position (the last one in this case) becomes effectively zero, and the softmax distributes the probability mass over the remaining, unmasked positions, so each row still sums to 1.
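Another common use of the mask argument is a causal (look-ahead) mask, used in decoder self-attention so that each position can only attend to itself and earlier positions. Below is a minimal sketch using a square case where queries and keys share the same length; the variable names are our own.

# Causal mask sketch: position i may only attend to positions <= i
seq_len = 4
q_causal = torch.randn(batch_size, seq_len, d_k)
k_causal = torch.randn(batch_size, seq_len, d_k)
v_causal = torch.randn(batch_size, seq_len, d_v)

# Lower-triangular boolean mask, broadcast over the batch dimension
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().unsqueeze(0)

out_causal, w_causal = scaled_dot_product_attention(q_causal, k_causal, v_causal, mask=causal_mask)
print(w_causal[0])  # Entries above the diagonal (future positions) should be ~0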
Visualizing the attention weights matrix, softmax(QKᵀ / √d_k), can provide insights into how the model relates different parts of the sequence. A heatmap is often used for this purpose. Let's visualize the weights from our unmasked example.
# Use the attention_weights from the unmasked example
weights_np = attention_weights[0].detach().numpy()  # Weights for the first batch item

# Create heatmap data
fig_data = go.Heatmap(
    z=weights_np,
    x=[f'Key Pos {i}' for i in range(seq_len_k)],
    y=[f'Query Pos {i}' for i in range(seq_len_q)],
    colorscale='Blues',  # Use blue color scale
    colorbar=dict(title='Attention Weight')
)

# Create layout
fig_layout = go.Layout(
    title='Attention Weights (Query vs Key)',
    xaxis_title="Key Sequence Position",
    yaxis_title="Query Sequence Position",
    yaxis_autorange='reversed',  # Show query 0 at the top
    width=500, height=400, margin=dict(l=50, r=50, b=100, t=100, pad=4)
)

# Build the figure and export it as JSON for web display
fig = go.Figure(data=[fig_data], layout=fig_layout)
plotly_json = fig.to_json()
Heatmap showing attention weights for each query position (rows) attending to each key position (columns). Higher values (darker blue) indicate stronger attention. Each row sums to 1.
This visualization shows, for each query position (row), how much "attention" or weight it gives to each key position (column) when computing its output. In a real application with meaningful data, patterns in this matrix can reveal syntactic or semantic relationships learned by the model. For instance, a verb might attend strongly to its subject and object.
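The same plotting code can be reused to inspect the masked example; with the mask applied, the last key column of the heatmap should drop to (nearly) zero. A short sketch:

# Heatmap for the masked example (reusing the layout defined above)
weights_masked_np = attention_weights_masked[0].detach().numpy()
fig_masked = go.Figure(
    data=[go.Heatmap(
        z=weights_masked_np,
        x=[f'Key Pos {i}' for i in range(seq_len_k)],
        y=[f'Query Pos {i}' for i in range(seq_len_q)],
        colorscale='Blues',
        colorbar=dict(title='Attention Weight'))],
    layout=fig_layout
)
fig_masked.update_layout(title='Masked Attention Weights (Query vs Key)')
plotly_json_masked = fig_masked.to_json()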
This implementation of Scaled Dot-Product Attention forms the core computational unit within each attention head in a Transformer. In the next chapter, we will build upon this foundation to implement Multi-Head Attention, allowing the model to jointly attend to information from different representational subspaces.
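As a small preview of that step: because our function only manipulates the last two dimensions, it already accepts tensors that carry an extra head dimension. And if you are running PyTorch 2.0 or later, you can cross-check the result against the built-in torch.nn.functional.scaled_dot_product_attention, which returns only the output, not the weights. A quick sketch with hypothetical sizes:

# Tensors with an extra head dimension (hypothetical sizes)
q4 = torch.randn(2, 4, 5, d_k)   # (batch, num_heads, seq_len_q, d_k)
k4 = torch.randn(2, 4, 6, d_k)   # (batch, num_heads, seq_len_k, d_k)
v4 = torch.randn(2, 4, 6, d_v)   # (batch, num_heads, seq_len_k, d_v)

out4, w4 = scaled_dot_product_attention(q4, k4, v4)
print(out4.shape)  # torch.Size([2, 4, 5, 16])
print(w4.shape)    # torch.Size([2, 4, 5, 6])

# Cross-check against PyTorch's fused kernel (PyTorch 2.0+ only)
builtin_out = F.scaled_dot_product_attention(q4, k4, v4)
print(torch.allclose(out4, builtin_out, atol=1e-5))  # Expected: True, up to floating-point tolerance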