Attention mechanisms have become fundamental components in many state-of-the-art deep learning models, particularly in sequence modeling tasks handled by architectures like the Transformer. At its core, attention allows a model to dynamically focus on different parts of an input sequence when producing an output, assigning varying weights or "attention" to different input elements based on their relevance to the current processing step. This contrasts with earlier sequence models like RNNs, which typically process information sequentially and can struggle with long-range dependencies due to information bottlenecks.
Instead of relying solely on the final hidden state of an encoder, attention mechanisms compute a context vector as a weighted sum of input representations. The weights are determined dynamically based on the relationship between the current output step (represented by a query) and the input elements (represented by keys).
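To make this concrete, here is a minimal sketch in TensorFlow (with arbitrary toy shapes, not tied to any particular model) of a context vector computed as a softmax-weighted sum of input representations:

import tensorflow as tf

# Toy shapes chosen purely for illustration: one query vector attending over
# three input representations of dimension 4.
query = tf.random.normal([1, 4])      # representation of the current output step
inputs = tf.random.normal([3, 4])     # representations of the input elements

# Relevance scores between the query and each input (dot product here).
scores = tf.matmul(query, inputs, transpose_b=True)    # shape (1, 3)
weights = tf.nn.softmax(scores, axis=-1)                # attention weights, sum to 1

# Context vector: weighted sum of the input representations.
context = tf.matmul(weights, inputs)                    # shape (1, 4)
print(weights.numpy(), context.shape)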
Let's implement the most common types of attention mechanisms from the ground up using TensorFlow.
Scaled dot-product attention is arguably the most widely used attention mechanism, forming the core of the Transformer architecture. It operates on three inputs: queries (Q), keys (K), and values (V). Think of it like a database retrieval system: for a given query (Q), we compute a score against each available key (K). These scores determine how much attention (weight) we pay to the corresponding values (V). The output is a weighted sum of the values.
The formula is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

where $Q$, $K$, and $V$ are the query, key, and value matrices, and $d_k$ is the dimension of the keys, used to scale the logits so the softmax does not saturate for large key dimensions.
Let's implement this using TensorFlow. We'll create a function that takes Q, K, V, and an optional mask as input. The mask is important for preventing attention to certain positions, such as padding tokens in sequences or future tokens in causal attention (used in decoders).
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask=None):
    """Calculates the attention weights and output.

    Args:
        q: Query tensor; shape == (..., seq_len_q, depth)
        k: Key tensor; shape == (..., seq_len_k, depth)
        v: Value tensor; shape == (..., seq_len_v, depth_v)
           Note: seq_len_k == seq_len_v
        mask: Optional float tensor with shape broadcastable
              to (..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
        output: Weighted sum of values, shape == (..., seq_len_q, depth_v)
        attention_weights: Attention scores after softmax,
            shape == (..., seq_len_q, seq_len_k)
    """
    # Calculate the dot product between queries and keys.
    # q shape: (..., seq_len_q, depth)
    # k shape: (..., seq_len_k, depth)
    # matmul_qk shape: (..., seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    # Scale matmul_qk by the square root of the key depth (d_k).
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Add the mask to the scaled tensor.
    # The mask adds large negative values to positions that should not be attended to.
    if mask is not None:
        # The mask must be broadcastable to the shape of scaled_attention_logits.
        # Typical mask shapes: (batch_size, 1, 1, seq_len_k) or (batch_size, 1, seq_len_q, seq_len_k).
        scaled_attention_logits += (mask * -1e9)  # Large negative number where mask == 1

    # Apply softmax to get attention weights.
    # attention_weights shape: (..., seq_len_q, seq_len_k)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # Multiply weights by values to get the weighted sum.
    # attention_weights shape: (..., seq_len_q, seq_len_k)
    # v shape: (..., seq_len_v, depth_v) where seq_len_v == seq_len_k
    # output shape: (..., seq_len_q, depth_v)
    output = tf.matmul(attention_weights, v)

    return output, attention_weights
# Example Usage (Illustrative Shapes)
batch_size = 4
seq_len_q = 10 # Target sequence length
seq_len_k = 12 # Source sequence length
depth = 64 # Key/Query dimension
depth_v = 128 # Value dimension
# Random example tensors
queries = tf.random.normal([batch_size, seq_len_q, depth])
keys = tf.random.normal([batch_size, seq_len_k, depth])
values = tf.random.normal([batch_size, seq_len_k, depth_v]) # seq_len_k matches keys
# Optional mask (e.g., masking padding tokens in the keys).
# Shape (batch_size, 1, 1, seq_len_k) so it broadcasts over the attention logits.
# Convention used here: 1 marks a masked position, 0 marks a valid one, so adding
# mask * -1e9 drives masked positions to ~zero weight after the softmax.
mask_values = tf.cast(tf.random.uniform((batch_size, 1, 1, seq_len_k)) > 0.5, tf.float32)  # Random example mask
# Calculate attention
attention_output, attention_weights = scaled_dot_product_attention(queries, keys, values, mask=mask_values)
print("Attention Output Shape:", attention_output.shape)
print("Attention Weights Shape:", attention_weights.shape)
This implementation captures the essence of scaled dot-product attention. Notice how tensor shapes must align correctly for the matrix multiplications. The optional mask allows flexible control over which key/value pairs contribute to the output for each query.
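Here is a sketch of how the two masks mentioned above might be built. The helper names create_padding_mask and create_look_ahead_mask are illustrative (they assume token id 0 denotes padding) and follow the convention used in the function above, where 1 marks a position to be blocked:

import tensorflow as tf

def create_padding_mask(token_ids):
    # 1.0 where the token id is 0 (assumed padding id), 0.0 elsewhere.
    # Shape (batch_size, 1, 1, seq_len) so it broadcasts over heads and query positions.
    mask = tf.cast(tf.math.equal(token_ids, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    # 1.0 above the diagonal: position i may not attend to positions j > i.
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

# Example: a batch of two padded sequences and a causal mask for length 4.
token_ids = tf.constant([[7, 5, 0, 0], [3, 9, 2, 0]])
print(create_padding_mask(token_ids).shape)   # (2, 1, 1, 4)
print(create_look_ahead_mask(4))              # upper-triangular ones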
Additive attention, often called Bahdanau attention, is another influential mechanism and was introduced slightly earlier than dot-product attention. Instead of using a simple dot product, it uses a feed-forward network with a single hidden layer to calculate the alignment score between the query and each key.
The score (or energy) between a query $q$ and a key $k_i$ is computed as:

$$\text{score}(q, k_i) = v^T \tanh(W_q q + W_k k_i + b)$$

where $W_q$ and $W_k$ are learned weight matrices, $v$ is a learned weight vector, and $b$ is a bias term. The scores for all keys are then passed through a softmax function to obtain the attention weights.
Additive attention can handle queries and keys of different dimensions because the projections Wqq and Wkki can map them to a common dimension before the tanh activation.
Here's how you might implement it as a Keras layer:
import tensorflow as tf
from tensorflow.keras import layers
class AdditiveAttention(layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Dense layers to project query and key to the same dimension (units)
        self.Wq = layers.Dense(units, use_bias=False, name='query_projection')
        self.Wk = layers.Dense(units, use_bias=False, name='key_projection')
        # Dense layer to compute the score from the combined projection
        self.V = layers.Dense(1, name='score_projection')  # Projects down to a single score

    def call(self, query, key, value, mask=None):
        """
        Args:
            query: Query tensor; shape == (batch_size, seq_len_q, query_depth)
                   or (batch_size, query_depth) for a single query
            key: Key tensor; shape == (batch_size, seq_len_k, key_depth)
            value: Value tensor; shape == (batch_size, seq_len_k, value_depth)
            mask: Optional mask tensor; shape broadcastable to
                  (batch_size, seq_len_q, seq_len_k).

        Returns:
            context_vector: Weighted sum of values; shape == (batch_size, seq_len_q, value_depth)
                            or (batch_size, value_depth) for a single query
            attention_weights: Attention scores after softmax;
                               shape == (batch_size, seq_len_q, seq_len_k)
                               or (batch_size, 1, seq_len_k) for a single query
        """
        # Remember whether the query arrived as a single vector per batch element
        # (e.g., a decoder state) so the added dimension can be removed at the end.
        single_query = query.shape.rank == 2
        if single_query:
            # query shape: (batch_size, 1, query_depth)
            query = tf.expand_dims(query, 1)  # Add seq_len_q dimension

        # Project query and key
        # query_proj shape: (batch_size, seq_len_q, units)
        # key_proj shape: (batch_size, seq_len_k, units)
        query_proj = self.Wq(query)
        key_proj = self.Wk(key)

        # Expand dims for the broadcasting addition
        # query_proj_expanded shape: (batch_size, seq_len_q, 1, units)
        # key_proj_expanded shape: (batch_size, 1, seq_len_k, units)
        query_proj_expanded = tf.expand_dims(query_proj, 2)
        key_proj_expanded = tf.expand_dims(key_proj, 1)

        # Calculate alignment scores (energy)
        # combined projection shape: (batch_size, seq_len_q, seq_len_k, units)
        # scores shape: (batch_size, seq_len_q, seq_len_k, 1) -> (batch_size, seq_len_q, seq_len_k)
        scores = self.V(tf.nn.tanh(query_proj_expanded + key_proj_expanded))
        scores = tf.squeeze(scores, axis=-1)  # Remove the trailing singleton dimension

        # Apply the mask before the softmax (1 marks a masked position)
        if mask is not None:
            mask = tf.cast(mask, tf.float32)
            if mask.shape.rank == 2:  # e.g., shape (batch_size, seq_len_k)
                mask = tf.expand_dims(mask, 1)  # -> (batch_size, 1, seq_len_k)
            scores += mask * -1e9

        # Compute attention weights using softmax
        # attention_weights shape: (batch_size, seq_len_q, seq_len_k)
        attention_weights = tf.nn.softmax(scores, axis=-1)

        # Compute the context vector (weighted sum of values)
        # attention_weights shape: (batch_size, seq_len_q, seq_len_k)
        # value shape: (batch_size, seq_len_k, value_depth)
        # context_vector shape: (batch_size, seq_len_q, value_depth)
        context_vector = tf.matmul(attention_weights, value)

        # If the original query was a single vector, remove the added dimension.
        # attention_weights keeps shape (batch_size, 1, seq_len_k) in that case.
        if single_query:
            context_vector = tf.squeeze(context_vector, axis=1)

        return context_vector, attention_weights
# Example Usage
batch_size = 4
seq_len_q = 1 # Example: Single query like decoder state
seq_len_k = 12 # Source sequence length
query_depth = 50
key_depth = 60
value_depth = 70
units = 32 # Hidden units in the additive attention mechanism
# Example tensors
single_query = tf.random.normal([batch_size, query_depth])
keys = tf.random.normal([batch_size, seq_len_k, key_depth])
values = tf.random.normal([batch_size, seq_len_k, value_depth])
# Optional mask (1 for masked, 0 for valid) -> Cast to bool for clarity
mask = tf.cast(tf.random.uniform((batch_size, seq_len_k)) > 0.8, tf.bool) # Mask ~20%
additive_attention_layer = AdditiveAttention(units)
context, weights = additive_attention_layer(single_query, keys, values, mask=mask)
print("Additive Attention Context Shape:", context.shape)
print("Additive Attention Weights Shape:", weights.shape)
# Example with sequence query
seq_query = tf.random.normal([batch_size, 10, query_depth]) # seq_len_q = 10
context_seq, weights_seq = additive_attention_layer(seq_query, keys, values, mask=None) # No mask for simplicity
print("\nAdditive Attention Context Shape (Sequence Query):", context_seq.shape)
print("Additive Attention Weights Shape (Sequence Query):", weights_seq.shape)
This Keras layer encapsulates the logic, making it reusable. The use of Dense layers makes learning the projection weights (Wq, Wk, V) straightforward during model training.
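As a usage sketch (with hypothetical encoder output and decoder state shapes, not taken from the text above), the layer slots naturally into a Bahdanau-style decoder step, where the decoder's hidden state acts as the single query over the encoder outputs, which serve as both keys and values:

# Hypothetical shapes for a single decoder step; reuses the AdditiveAttention layer defined above.
encoder_outputs = tf.random.normal([4, 12, 60])   # (batch, src_len, enc_dim); keys == values
decoder_state = tf.random.normal([4, 50])         # (batch, dec_dim); one query per step

attn = AdditiveAttention(units=32)
context, attn_weights = attn(decoder_state, encoder_outputs, encoder_outputs)
# The context vector would typically be concatenated with the decoder input
# embedding before the next RNN step.
print(context.shape)        # (4, 60)
print(attn_weights.shape)   # (4, 1, 12)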
Multi-Head Attention (MHA) enhances the basic attention mechanism by performing multiple attention calculations in parallel, each focusing on different aspects or "representation subspaces" of the input.
The process involves four steps: linearly projecting the queries, keys, and values for each head, computing scaled dot-product attention in parallel across the heads, concatenating the per-head outputs, and applying a final linear transformation.

Figure: Flow of Multi-Head Attention. Inputs are projected multiple times, attention is computed in parallel for each head, outputs are concatenated, and a final linear transformation is applied.
This allows the model to capture different types of relationships at different positions simultaneously. Let's implement this as a Keras layer.
import tensorflow as tf
from tensorflow.keras import layers
class MultiHeadAttention(layers.Layer):
    def __init__(self, d_model, num_heads, **kwargs):
        """
        Args:
            d_model: Total dimension of the model (output dimension).
                     Must be divisible by num_heads.
            num_heads: Number of attention heads.
        """
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model

        # Ensure d_model is divisible by num_heads
        assert d_model % self.num_heads == 0

        # Depth of each attention head's projection
        self.depth = d_model // self.num_heads

        # Dense layers for linear projections of Q, K, V
        self.wq = layers.Dense(d_model, name='query_projection')   # Projects to d_model
        self.wk = layers.Dense(d_model, name='key_projection')     # Projects to d_model
        self.wv = layers.Dense(d_model, name='value_projection')   # Projects to d_model

        # Dense layer for the final linear projection
        self.dense = layers.Dense(d_model, name='output_projection')

    def split_heads(self, x, batch_size):
        """Splits the last dimension into (num_heads, depth) and transposes
        the result so the shape is (batch_size, num_heads, seq_len, depth).

        Args:
            x: Input tensor; shape == (batch_size, seq_len, d_model)
            batch_size: Batch size.

        Returns:
            Tensor with shape (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """
        Args:
            v: Value tensor; shape == (batch_size, seq_len_v, d_model)
            k: Key tensor; shape == (batch_size, seq_len_k, d_model)
            q: Query tensor; shape == (batch_size, seq_len_q, d_model)
            mask: Optional mask.

        Returns:
            output: Final attention output; shape == (batch_size, seq_len_q, d_model)
            attention_weights: Attention weights; shape ==
                (batch_size, num_heads, seq_len_q, seq_len_k)
        """
        batch_size = tf.shape(q)[0]

        # 1. Linear projections
        # q, k, v shape: (batch_size, seq_len, d_model)
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # 2. Split into multiple heads
        # q, k, v shape: (batch_size, num_heads, seq_len, depth)
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # 3. Apply scaled dot-product attention per head (function defined earlier)
        # scaled_attention shape: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights shape: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        # 4. Concatenate heads
        # Transpose back: (batch_size, seq_len_q, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        # Concatenate: (batch_size, seq_len_q, d_model)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))

        # 5. Final linear projection
        # output shape: (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)

        return output, attention_weights
# Example Usage
batch_size = 4
seq_len = 15 # Assuming Q, K, V sequences have same length for simplicity here
d_model = 128 # Model dimension
num_heads = 8
# Example inputs (often Q, K, V come from the same source in self-attention)
input_seq = tf.random.normal([batch_size, seq_len, d_model])
mha_layer = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
# Self-Attention example: Q, K, V are the same
output, weights = mha_layer(v=input_seq, k=input_seq, q=input_seq, mask=None)
print("Multi-Head Attention Output Shape:", output.shape)
print("Multi-Head Attention Weights Shape:", weights.shape) # Note the num_heads dimension
This MultiHeadAttention layer is a fundamental building block for Transformer encoders and decoders. Note the split_heads and concatenation steps, which are essential for managing the parallel computations across heads.
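To show how this layer is typically used, here is a minimal sketch of a Transformer-style encoder block built around the MultiHeadAttention layer above. The feed-forward width (dff), dropout rate, and the EncoderBlock name are illustrative choices rather than anything prescribed by the text:

import tensorflow as tf
from tensorflow.keras import layers

class EncoderBlock(layers.Layer):
    """A minimal Transformer encoder block sketch using the MultiHeadAttention layer above."""
    def __init__(self, d_model, num_heads, dff=512, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = tf.keras.Sequential([
            layers.Dense(dff, activation='relu'),
            layers.Dense(d_model),
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, x, mask=None, training=False):
        # Self-attention: queries, keys, and values all come from x.
        attn_output, _ = self.mha(v=x, k=x, q=x, mask=mask)
        x = self.norm1(x + self.dropout1(attn_output, training=training))
        # Position-wise feed-forward network with a second residual connection.
        ffn_output = self.ffn(x)
        return self.norm2(x + self.dropout2(ffn_output, training=training))

# Quick shape check with the earlier example dimensions.
block = EncoderBlock(d_model=128, num_heads=8)
print(block(tf.random.normal([4, 15, 128])).shape)  # (4, 15, 128)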
Self-attention is a specific application of an attention mechanism (like Scaled Dot-Product or Multi-Head Attention) where the queries, keys, and values are all derived from the same input sequence. This allows the model to weigh the importance of different words or tokens within the same sequence when computing the representation for each token.
For instance, in the sentence "The animal didn't cross the street because it was too tired," self-attention can help the model learn that "it" refers to "the animal" rather than "the street."
To implement self-attention, you simply pass the same input tensor (or tensors derived from it through initial projections) as the query, key, and value arguments to your attention layer, such as the MultiHeadAttention layer shown above. The example usage for MultiHeadAttention already demonstrated this case.
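For decoder-style (causal) self-attention, the only change is the mask: a look-ahead mask blocks each position from attending to later positions. The sketch below reuses the mha_layer and input_seq from the multi-head attention example, with illustrative shapes:

# Causal self-attention: 1 above the diagonal marks future positions to block.
seq_len = 15
look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
# Add batch and head dimensions so it broadcasts to (batch, heads, seq_q, seq_k).
look_ahead_mask = look_ahead_mask[tf.newaxis, tf.newaxis, :, :]

causal_output, causal_weights = mha_layer(v=input_seq, k=input_seq,
                                          q=input_seq, mask=look_ahead_mask)
# Each query position now receives non-zero weight only for itself and earlier positions.
print(causal_output.shape)   # (4, 15, 128)
print(causal_weights.shape)  # (4, 8, 15, 15)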
These building blocks (scaled dot-product attention, additive attention, and multi-head attention) provide powerful and flexible ways to incorporate context and focus into deep learning models. Understanding how to implement them from scratch is valuable for customizing architectures and interpreting model behavior, and it forms the foundation for implementing advanced models like Transformers.