One of the most direct ways to gain some insight into the inner workings of a Transformer model is by examining its attention mechanisms. Recall from Chapter 4, "The Transformer Architecture", that self-attention allows the model to weigh the importance of different tokens in the input sequence when computing the representation for a specific token. These attention weights, calculated for each head in each layer, form maps that show how information is routed through the model. Visualizing these maps can provide clues about the relationships the model has learned between tokens.
The core idea of self-attention involves computing scores between a token's query vector (Q) and the key vectors (K) of all tokens in the sequence (including itself). These scores are scaled, normalized using softmax, and then used to compute a weighted sum of the value vectors (V). The attention weights are the result of the scaled dot-product followed by the softmax:
$$\text{Weights} = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)$$
Here, $d_k$ is the dimension of the key vectors. These weights represent the distribution of attention from each query token to all key tokens. A higher weight indicates that the model considers the corresponding key token more significant when generating the representation for the query token.
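To make this concrete, the short sketch below computes the weights for a single head directly from the formula. The query and key matrices here are random, purely illustrative stand-ins rather than outputs of a trained model.

import torch
import torch.nn.functional as F

seq_len, d_k = 5, 8

# Illustrative query and key matrices for one head: (seq_len, d_k)
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)

# Scaled dot-product scores: (seq_len, seq_len)
scores = Q @ K.T / (d_k ** 0.5)

# Softmax over the key dimension yields the attention weights
weights = F.softmax(scores, dim=-1)

print(weights.shape)        # torch.Size([5, 5])
print(weights.sum(dim=-1))  # each row sums to 1

Each row of weights is the distribution that one query token places over all key tokens, which is exactly what the visualizations later in this section display.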
Most modern deep learning frameworks, including PyTorch, provide mechanisms to access these attention weights during a forward pass. When using PyTorch's nn.MultiheadAttention layer, you can pass need_weights=True in the forward call. This instructs the layer to return the attention weights, averaged across all heads by default, in addition to the layer's output. For head-specific weights, recent PyTorch versions accept average_attn_weights=False; for other implementations, you can use hooks to capture the weights before they are averaged.
Here's a simplified example illustrating how to get attention weights from an nn.MultiheadAttention layer in PyTorch:
import torch
import torch.nn as nn
# Example setup
seq_len = 5
embed_dim = 8
num_heads = 2
batch_size = 1
# Ensure embed_dim is divisible by num_heads
assert embed_dim % num_heads == 0
mha_layer = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
# Dummy input (batch_size, seq_len, embed_dim)
query = torch.randn(batch_size, seq_len, embed_dim)
key = torch.randn(batch_size, seq_len, embed_dim)
value = torch.randn(batch_size, seq_len, embed_dim)
# Forward pass requesting attention weights
# attn_output: (batch_size, seq_len, embed_dim)
# attn_output_weights: (batch_size, seq_len, seq_len) -> Average over heads
attn_output, attn_output_weights = mha_layer(
    query, key, value,
    need_weights=True,
    average_attn_weights=True  # set False for per-head weights in recent PyTorch versions
)
print("Shape of averaged attention weights:", attn_output_weights.shape)
# With average_attn_weights=False, the shape would instead be
# (batch_size, num_heads, seq_len, seq_len)

# Example: access the weights for the first batch item
first_batch_weights = attn_output_weights[0]  # shape: (seq_len, seq_len)
# first_batch_weights[i, j] is the attention from query token i to key token j

# Per-head weights (returned directly by recent PyTorch versions)
_, attn_output_weights_per_head = mha_layer(
    query, key, value,
    need_weights=True,
    average_attn_weights=False
)
print("Shape of per-head attention weights:",
      attn_output_weights_per_head.shape)  # (batch_size, num_heads, seq_len, seq_len)
first_batch_head_0_weights = attn_output_weights_per_head[0, 0]  # head 0
Note that the standard nn.MultiheadAttention returns weights averaged across heads when average_attn_weights is True, which is the default. Passing average_attn_weights=False (supported in recent PyTorch releases) returns the individual head weights directly. For attention implementations that do not expose such an option, you can either modify the forward method or, more cleanly, register a forward hook on the attention mechanism's internal softmax or matrix multiplication operations to capture the weights before they are averaged.
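As one sketch of the hook approach, the snippet below assumes a hypothetical custom attention module, SimpleSelfAttention (not part of PyTorch), that applies its softmax through an nn.Softmax submodule; a forward hook registered on that submodule captures the per-head weights before any averaging or output projection.

import torch
import torch.nn as nn

# Hypothetical custom attention block that exposes its softmax as a submodule
class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.softmax = nn.Softmax(dim=-1)  # hook target
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        b, t, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape each to (batch, heads, seq, head_dim)
        q = q.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        weights = self.softmax(scores)  # (batch, heads, seq, seq)
        out = (weights @ v).transpose(1, 2).reshape(b, t, -1)
        return self.out_proj(out)

captured = {}

def save_attention(module, inputs, output):
    # The softmax output is the per-head attention map
    captured["weights"] = output.detach()

attn = SimpleSelfAttention(embed_dim=8, num_heads=2)
handle = attn.softmax.register_forward_hook(save_attention)
_ = attn(torch.randn(1, 5, 8))
handle.remove()
print(captured["weights"].shape)  # torch.Size([1, 2, 5, 5])

The same idea applies to any implementation that routes the normalized weights through a dedicated submodule; for layers that apply the softmax functionally, modifying the forward method remains the fallback.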
Once extracted, attention weights, typically matrices of size (sequence_length, sequence_length) for each head and layer, can be visualized in several ways:
Heatmaps: The weight matrix for a head can be plotted as a heatmap in which cell (i, j) indicates the attention weight from query token i to key token j. Lighter colors often signify higher attention. Analyzing these heatmaps can reveal patterns, such as strong diagonal lines (tokens attending to themselves), attention to preceding tokens, or specific tokens (like punctuation or special tokens) acting as information sinks or sources. A minimal plotting sketch follows this list.
Hypothetical attention weights for a single head. Notice the strong diagonal indicating self-attention, and how "sat" attends strongly to "cat". The special token [CLS] attends mostly to itself, while [SEP] also shows high self-attention.
Multi-Head Visualization: Since each layer contains multiple attention heads, it is often useful to visualize them together. A common technique is to arrange the per-head heatmaps for a layer in a grid so their patterns can be compared side by side; interactive tools such as BertViz take a similar per-head view, drawing attention lines between tokens.
Graph-Based Visualization: Attention weights can be represented as a directed graph where tokens are nodes and a directed edge from token i to token j exists if the attention weight $w_{ij}$ exceeds a certain threshold. Edge thickness or color can represent the weight's magnitude. This can be effective for visualizing connections in shorter sequences or highlighting specific strong relationships.
Simplified graph showing strong attention links (hypothetical). "sat" strongly attends to "cat", while "cat" attends significantly to "the".
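As referenced in the heatmap item above, the sketch below plots a single head's weight matrix with matplotlib and then lists the strong edges for a graph-style view. The token list, the random weight matrix, and the 0.3 threshold are illustrative choices, not taken from any particular model.

import torch
import matplotlib.pyplot as plt

# Illustrative tokens and a random, row-normalized (seq_len, seq_len) weight matrix
tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]
weights = torch.softmax(torch.randn(len(tokens), len(tokens)), dim=-1)

# Heatmap: rows are query tokens, columns are key tokens
fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(weights.numpy(), cmap="viridis")
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, rotation=45)
ax.set_yticks(range(len(tokens)))
ax.set_yticklabels(tokens)
ax.set_xlabel("Key token")
ax.set_ylabel("Query token")
fig.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

# Graph-style view: keep only edges whose weight exceeds a threshold
threshold = 0.3
for i, q_tok in enumerate(tokens):
    for j, k_tok in enumerate(tokens):
        if weights[i, j] > threshold:
            print(f"{q_tok} -> {k_tok}: {weights[i, j].item():.2f}")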
Analyzing attention patterns can sometimes reveal linguistically plausible behaviors. For example, special tokens such as [CLS] or [SEP] might aggregate information from the entire sequence, indicated by broad attention patterns originating from or targeting them.
While attention visualization is a valuable tool, it's important to be aware of its limitations: high attention weights do not necessarily mean a token was causally important for the model's prediction, the weights say nothing about the magnitudes of the value vectors they combine, and information also flows through residual connections and feed-forward layers that attention maps do not capture.
Attention map visualization provides a window, albeit a foggy one sometimes, into the flow of information within a Transformer. It's a useful diagnostic technique for hypothesis generation about model behavior and identifying potential areas of interest, but conclusions should be drawn cautiously and ideally corroborated with other analysis methods discussed later in this chapter, such as probing internal representations or analyzing neuron activations.