Autoregressive generation, the process of producing text one token at a time based on previously generated tokens, forms the core of how large language models generate responses. However, a naive implementation faces a significant performance hurdle. Recall the self-attention mechanism within the Transformer architecture (Chapter 4). To generate the next token, say token t+1, the standard self-attention calculation involves computing Queries (Q), Keys (K), and Values (V) based on all preceding tokens 1...t, and then computing attention scores. When generating the subsequent token t+2, this entire process is repeated using tokens 1...t+1. Notice the redundancy: the Key and Value vectors calculated for tokens 1...t during the generation of token t+1 are identical to those needed for the first t tokens when generating token t+2. Repeating these calculations at every step is computationally wasteful, especially as the sequence length grows.
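To make this concrete, here is a minimal sketch of a naive greedy decoding loop. The model here is a placeholder for any causal language model that maps token IDs of shape [batch, length] to logits of shape [batch, length, vocab_size]; it is not a specific library API. Every iteration re-runs the model on the entire prefix, so the Keys and Values for earlier tokens are recomputed from scratch each time.

import torch

def naive_generate(model, input_ids, max_new_tokens):
    # Each step feeds the FULL prefix back through the model, recomputing
    # K and V for every earlier token even though they have not changed.
    for _ in range(max_new_tokens):
        logits = model(input_ids)                        # [B, L, vocab]
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_id], dim=-1)
    return input_ids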
Key-Value (KV) caching is a fundamental optimization technique designed specifically to eliminate this redundancy in autoregressive inference. The core idea is simple yet effective: store the computed Key (K) and Value (V) tensors from the self-attention layers for all previous tokens and reuse them in subsequent generation steps.
In a Transformer's self-attention layer, the input sequence X is projected into three matrices: Query (Q), Key (K), and Value (V).
$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

where $W_Q$, $W_K$, $W_V$ are learnable weight matrices. The attention output is then computed as:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Consider generating token $t+1$. The model takes the sequence of token embeddings $x_1, \dots, x_t$ as input. Within each attention layer, it computes $K_1, \dots, K_t$ and $V_1, \dots, V_t$. It also computes the Query vector $Q_{t+1}$ based on the embedding of the last token $x_t$ (or a positional embedding corresponding to position $t+1$). The attention calculation then uses $Q_{t+1}$ with the full set of Keys $K = [K_1, \dots, K_t]$ and Values $V = [V_1, \dots, V_t]$.

Now, consider generating token $t+2$. The input sequence is $x_1, \dots, x_{t+1}$. The model needs to compute $K_1, \dots, K_{t+1}$ and $V_1, \dots, V_{t+1}$. Crucially, the computations for $K_1, \dots, K_t$ and $V_1, \dots, V_t$ are exactly the same as in the previous step, because they depend only on the input tokens $x_1, \dots, x_t$ and the fixed weight matrices $W_K$ and $W_V$.
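The following small PyTorch check makes this explicit. The dimensions are toy values and a single random linear layer stands in for $W_K$: appending a new token leaves the Key vectors of the earlier tokens unchanged.

import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
W_K = nn.Linear(d_model, d_model, bias=False)  # stands in for the fixed W_K

x_t = torch.randn(1, 5, d_model)                                  # tokens x_1..x_t (t = 5)
x_t_plus_1 = torch.cat([x_t, torch.randn(1, 1, d_model)], dim=1)  # tokens x_1..x_{t+1}

K_step1 = W_K(x_t)          # Keys computed while generating token t+1
K_step2 = W_K(x_t_plus_1)   # Keys recomputed while generating token t+2
print(torch.allclose(K_step1, K_step2[:, :5]))  # True: the first t rows are identical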
KV caching exploits exactly this redundancy. Instead of recomputing the Keys and Values for every previous token at each step, the model computes K and V only for the newest token, appends them to the cached tensors from earlier steps, and applies the new token's Query to the full cached set.
Figure: simplified flow showing how Keys (K) and Values (V) computed at step $t$ are cached and reused at step $t+1$, requiring only the computation for the new token $x_{t+1}$.
This drastically reduces the computational cost per generated token. Without the cache, each step's attention work grows with the current sequence length $t$: roughly $O(t^2)$ if the full score matrix is recomputed, or $O(t)$ just for applying the new query to existing keys. With the cache, the computation related to past tokens becomes effectively constant time (a cache lookup and a concatenation), and the main cost is computing K and V for the single new token and applying the new Query to the cached Keys, which is $O(t)$ for the $QK^T$ part.
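A rough back-of-the-envelope comparison illustrates the gap. The dimensions below are assumed for illustration only, and the estimate covers a single attention layer's K/V projections and score computation, ignoring everything else:

# Rough per-step cost comparison for one attention layer
# (illustrative, assumed dimensions; ignores the Q/output projections and MLP).
d_model = 4096          # assumed model width
t = 2048                # tokens generated so far

# Without caching: recompute K and V projections for all t tokens,
# plus the full t x t score matrix.
naive_flops = 2 * 2 * t * d_model * d_model + 2 * t * t * d_model

# With caching: project K and V for the single new token,
# and score one query against the t cached keys.
cached_flops = 2 * 2 * 1 * d_model * d_model + 2 * 1 * t * d_model

print(f"naive:  ~{naive_flops / 1e9:.1f} GFLOPs per step")
print(f"cached: ~{cached_flops / 1e9:.3f} GFLOPs per step")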
While KV caching significantly speeds up inference, it introduces a memory cost. The cache needs to store the Key and Value tensors for all preceding tokens, across all layers and all attention heads, for every sequence in the batch. The size of the KV cache can be estimated as:
Cache Size ≈ batch_size × num_layers × 2 (for K and V) × num_heads × sequence_length × head_dimension × bytes_per_element
This memory footprint grows linearly with sequence_length. For models with many layers and heads, and when processing long sequences or large batches, the KV cache can consume a substantial amount of GPU memory, sometimes becoming the limiting factor for the maximum sequence length that can be handled. Managing this memory usage is an important consideration, leading to techniques like paged attention or quantization of the cache itself, although those are beyond the scope of this basic introduction.
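Plugging in assumed numbers for a hypothetical 7B-class configuration (32 layers, 32 heads, head dimension 128, fp16 cache) makes the scale concrete; these values are illustrative, not taken from any specific model:

# Illustrative KV-cache size estimate; the configuration values are assumed.
batch_size = 1
num_layers = 32
num_heads = 32
head_dim = 128
sequence_length = 4096
bytes_per_element = 2   # fp16 / bf16

cache_bytes = (batch_size * num_layers * 2 * num_heads
               * sequence_length * head_dim * bytes_per_element)
print(f"KV cache: ~{cache_bytes / 1024**3:.1f} GiB")   # 2.0 GiB for this setup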
Implementing KV caching typically involves modifying the forward method of the Transformer blocks (or the attention modules directly) to accept an optional past_key_values argument and return updated present_key_values.
Here's a highly simplified sketch of an attention module with an optional KV cache: when past_kv is None it performs the standard computation, and otherwise it reuses the cached Keys and Values:
import torch
import torch.nn as nn
# Simplified MultiHeadAttention placeholder illustrating KV caching
# (causal masking and dropout are omitted for clarity)
class SimpleMultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, (
"embed_dim must be divisible by num_heads"
)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, past_kv=None):
batch_size, seq_len, _ = query.size()
# Project query, key, value
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
# Reshape for multi-head attention
q = q.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_q, hd]
k = k.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_k, hd]
v = v.view(
batch_size, -1, self.num_heads, self.head_dim
).transpose(1, 2) # [B, nh, L_v, hd]
present_kv = None
if past_kv is not None:
# past_kv is a tuple (past_k, past_v)
# Each has shape [B, nh, L_past, hd]
past_k, past_v = past_kv
# Concatenate along the sequence length dimension (dim=2)
k = torch.cat((past_k, k), dim=2)
v = torch.cat((past_v, v), dim=2)
# Store the updated K, V for the next step
present_kv = (k, v) # Shape [B, nh, L_past + L_k, hd]
else:
# Store K, V for the first time
present_kv = (k, v) # Shape [B, nh, L_k, hd]
# Compute attention scores
# q: [B, nh, L_q, hd], k.transpose: [B, nh, hd, L_k]
# -> scores: [B, nh, L_q, L_k]
scores = torch.matmul(q, k.transpose(-2, -1))
scores = scores / (self.head_dim ** 0.5)
attn_weights = torch.softmax(scores, dim=-1)
# Apply attention weights to values
# attn_weights: [B, nh, L_q, L_k], v: [B, nh, L_v, hd]
# -> output: [B, nh, L_q, hd]
# (L_v == L_k assumed here)
output = torch.matmul(attn_weights, v)
# Reshape and project output
output = output.transpose(1, 2).contiguous()
output = output.view(
batch_size, -1, self.embed_dim
) # [B, L_q, embed_dim]
output = self.out_proj(output)
# Return output and the updated key-value cache for this layer
return output, present_kv
# --- Usage during generation ---
# model = YourTransformerModel(...)
# kv_cache = None # Initially empty cache (list or tuple per layer)
# input_ids = initial_prompt_ids
# for _ in range(max_new_tokens):
# # Prepare input for the current step
# # (usually just the last generated token)
# current_input_ids = input_ids[:, -1:] # Shape [B, 1]
# # Forward pass with cache
# # Note: The model's forward needs to handle passing cache down
# # and collecting updates
# outputs = model(
# input_ids=current_input_ids,
# past_key_values=kv_cache,
# use_cache=True
# )
# logits = outputs.logits
# kv_cache = outputs.past_key_values # Update cache for the next iteration
# # Get the predicted next token ID (e.g., using argmax or sampling)
# next_token_id = torch.argmax(
# logits[:, -1:, :], dim=-1
# ) # Shape [B, 1]
# # Append the new token ID for the next iteration's full input
# # (though only last is used for Q)
# input_ids = torch.cat([input_ids, next_token_id], dim=-1)
# # Check stopping conditions, etc.
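As a quick sanity check of the sketch above (under its simplifying assumptions), incremental decoding with the cache should reproduce a full uncached pass for the last position; the comparison uses the last position only, where causal and full attention coincide:

import torch

torch.manual_seed(0)
attn = SimpleMultiHeadAttention(embed_dim=64, num_heads=4)

x = torch.randn(2, 5, 64)   # [batch, seq_len, embed_dim]

with torch.no_grad():
    # Full pass over the whole sequence at once (no cache).
    full_out, _ = attn(x, x, x)

    # Token-by-token pass, carrying the cache forward.
    kv_cache = None
    for i in range(x.size(1)):
        step = x[:, i:i + 1, :]                       # only the newest token
        step_out, kv_cache = attn(step, step, step, past_kv=kv_cache)

# The final incremental output matches the last position of the full pass.
print(torch.allclose(full_out[:, -1:], step_out, atol=1e-5))  # True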
In practice, frameworks like Hugging Face Transformers abstract this caching mechanism away. When calling the generate method, or when using the model's forward pass with the use_cache=True argument, the framework automatically handles the creation, passing, and updating of the KV cache between generation steps. However, understanding the underlying principle is important for appreciating the performance gains and memory implications.
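For example, with Hugging Face Transformers (shown here with the small gpt2 checkpoint purely for illustration; use_cache is typically enabled by default), cached generation looks like this:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("KV caching speeds up", return_tensors="pt")
# generate() creates, passes, and updates the KV cache internally.
output_ids = model.generate(**inputs, max_new_tokens=20, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))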
KV caching is a cornerstone of efficient Transformer inference. It directly addresses the redundant recomputation that makes naive autoregressive decoding scale quadratically with sequence length, making the generation of longer sequences feasible in practice. While it introduces memory overhead, the computational savings almost always make it an indispensable optimization.