Understanding how a trained Transformer makes predictions is as significant as optimizing its training speed or memory usage. Analyzing attention weight distributions provides a window into the model's internal reasoning process. It allows us to observe which parts of the input sequence the model focuses on when processing a specific token. This practice is invaluable for debugging, verifying model behavior, and gaining insights into whether the model has learned meaningful relationships or is relying on spurious correlations.
Most modern deep learning frameworks provide mechanisms to access intermediate activations within a model, including attention weights. The exact method depends on the framework and how the Transformer model is implemented. Common approaches include:
Modify the forward pass: Alter the model's forward method (or equivalent) to return attention weights alongside the main output. This is often straightforward if you have control over the model's source code.
Register forward hooks: Attach hooks to the attention modules (e.g., a framework-provided multi-head attention module or a custom Attention layer). These hooks can capture the outputs (or inputs) of a module during the forward or backward pass without permanently altering the model code.
Use library flags: Some libraries, such as Hugging Face Transformers, provide a configuration argument (e.g., output_attentions=True) that instructs the model to return attention weights; a short sketch of this approach appears below.
Regardless of the method, the goal is to retrieve the attention probability matrices, typically computed after the softmax operation within each attention head. For a standard multi-head attention layer, the output attention weights often have a shape like [batch_size, num_heads, sequence_length_query, sequence_length_key]. For self-attention, sequence_length_query and sequence_length_key are the same (the input sequence length). For cross-attention in the decoder, sequence_length_key corresponds to the encoder output sequence length.
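Here is a minimal sketch of the library-flag approach using Hugging Face Transformers. The checkpoint name "bert-base-uncased" is only an example; any model that supports output_attentions=True behaves the same way.
# Minimal sketch: extracting attention weights with output_attentions=True
# ("bert-base-uncased" is just an example checkpoint)
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("The quick brown fox jumps", return_tensors="pt")
outputs = model(**inputs, output_attentions=True)

# outputs.attentions is a tuple with one tensor per layer,
# each shaped [batch_size, num_heads, seq_len, seq_len]
attention_probs = outputs.attentions[0][0]  # first layer, first example in the batch
print(attention_probs.shape)                # [num_heads, seq_len, seq_len]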
Let's assume you have obtained the attention weights for a specific layer and a single example from your batch. The resulting tensor, let's call it attention_probs, might have the shape [num_heads, seq_len, seq_len]. We can select a specific head for visualization, resulting in a 2D matrix of shape [seq_len, seq_len].
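For example, picking a single head and converting it to a NumPy array for plotting (assuming attention_probs is a PyTorch tensor, as in the sketch above) could look like this:
# Select one head and move it to NumPy for plotting
head_index = 0  # any head index works; 0 is just an example
attention_matrix = attention_probs[head_index].detach().cpu().numpy()
print(attention_matrix.shape)  # (seq_len, seq_len)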
The most common way to visualize these 2D attention matrices is using heatmaps. Each cell (i,j) in the heatmap represents the attention weight from the i-th query token to the j-th key token. Higher values (brighter colors) indicate stronger attention.
Consider a simple input sentence: "The quick brown fox jumps". Let's visualize the self-attention weights from the first layer, first head.
# Conceptual Python code using Matplotlib/Seaborn
# Assumes 'attention_matrix' is a [seq_len, seq_len] numpy array
# Assumes 'tokens' is a list of strings: ["The", "quick", "brown", "fox", "jumps"]
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Example attention data (replace with actual extracted weights)
# Rows: Query tokens (attending FROM)
# Columns: Key tokens (attending TO)
tokens = ["The", "quick", "brown", "fox", "jumps"]
seq_len = len(tokens)
attention_matrix = np.random.rand(seq_len, seq_len)
# Make diagonal slightly stronger for realism
np.fill_diagonal(attention_matrix, attention_matrix.diagonal() + 0.3)
# Normalize rows to sum to 1 (like softmax output)
attention_matrix /= attention_matrix.sum(axis=1, keepdims=True)
plt.figure(figsize=(7, 6))
sns.heatmap(attention_matrix, xticklabels=tokens, yticklabels=tokens, cmap="viridis", annot=True, fmt=".2f")
plt.xlabel("Key Tokens (Attending To)")
plt.ylabel("Query Tokens (Attending From)")
plt.title("Self-Attention Weights (Layer 1, Head 1)")
plt.xticks(rotation=45)
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()
Here is an example using Plotly for web visualization, showing hypothetical attention weights for the same sentence.
Heatmap showing attention scores between tokens. Rows represent the token generating the query, and columns represent the tokens generating keys/values. Brighter cells indicate higher attention scores.
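A minimal Plotly sketch that builds an equivalent heatmap from the attention_matrix and tokens defined above might look like this:
# Conceptual Plotly version of the same heatmap
import plotly.graph_objects as go

fig = go.Figure(data=go.Heatmap(
    z=attention_matrix,
    x=tokens,  # key tokens (attending to)
    y=tokens,  # query tokens (attending from)
    colorscale="Viridis",
))
# Reverse the y-axis so the first query token appears at the top,
# matching the Matplotlib/Seaborn orientation
fig.update_yaxes(autorange="reversed")
fig.update_layout(
    title="Self-Attention Weights (Layer 1, Head 1)",
    xaxis_title="Key Tokens (Attending To)",
    yaxis_title="Query Tokens (Attending From)",
)
fig.show()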
When analyzing these visualizations, look for recurring patterns. A common one is strong attention to special tokens such as [CLS], [SEP], [BOS], or [EOS]. Sometimes, the [CLS] token (in BERT-like models) aggregates information from the entire sequence, showing broad attention. In decoders, tokens often attend to the [EOS] (end-of-sequence) token of the previous sentence or segment.
While insightful, attention visualization is not a definitive explanation of model behavior; a high attention weight does not by itself prove that a token was causally important for the prediction.
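To go beyond eyeballing heatmaps, you can quantify how much attention mass flows to special-token positions. The sketch below reuses outputs, inputs, and tokenizer from the extraction example earlier; the variable names are illustrative.
# Rough check: attention mass flowing to special tokens ([CLS], [SEP])
# Reuses `outputs`, `inputs`, and `tokenizer` from the extraction sketch above
token_strings = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
special_positions = [i for i, t in enumerate(token_strings) if t in {"[CLS]", "[SEP]"}]

layer_attn = outputs.attentions[0][0]  # first layer, first example: [num_heads, seq_len, seq_len]
mass_to_special = layer_attn[:, :, special_positions].sum(dim=-1)  # [num_heads, seq_len]
print(mass_to_special.mean(dim=-1))  # average fraction of attention per head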
Analyzing attention weights is a practical technique for gaining qualitative insights into your Transformer model. It complements quantitative evaluation metrics by helping you understand how the model processes information, which is essential for building more effective and reliable systems. This practice helps confirm whether complex mechanisms like learned positional embeddings, layer normalization strategies, or specific optimizers are leading to sensible internal representations and information flow.