While techniques like KV caching optimize the computation within a single forward pass for autoregressive generation, the fundamental limitation remains: generating N tokens requires N sequential forward passes through the large language model. This sequential dependency inherently limits the maximum achievable speed. Speculative decoding offers a clever way to parallelize parts of this process, aiming to generate multiple tokens per single forward pass of the main, large model, thereby reducing overall wall-clock latency.
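For contrast, here is a minimal sketch of the standard autoregressive loop that speculative decoding tries to shortcut. The greedy argmax sampling and the HuggingFace-style .logits attribute are simplifying assumptions for illustration:

import torch

@torch.no_grad()
def autoregressive_generate(model, input_ids, max_new_tokens):
    # Each iteration requires one full forward pass of the large model,
    # so generating N tokens costs N strictly sequential passes.
    for _ in range(max_new_tokens):
        logits = model(input_ids).logits  # [batch, seq_len, vocab]
        next_token = logits[:, -1, :].argmax(-1, keepdim=True)  # greedy, for simplicity
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids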
The core idea relies on using two models:

- A small, fast draft model that can propose candidate tokens cheaply.
- The large target model, whose output distribution we actually want, used only to verify those proposals.

Instead of the target model generating one token at a time, the process works as follows:

1. The draft model autoregressively generates k candidate tokens.
2. The target model evaluates the original input plus all k candidates in a single forward pass.
3. An acceptance test compares the target and draft probabilities for each drafted token in order, keeping tokens until the first rejection.
4. One additional token is sampled, either from a corrected distribution at the rejection point or from the target model's distribution after the last accepted draft token.
The potential speedup comes from the fact that if n of the k drafted tokens are accepted in a cycle, we have effectively generated n+1 tokens (the n accepted ones plus one final sampled token) using only one forward pass of the expensive target model and k fast forward passes of the draft model. If the draft model approximates the target model well, the acceptance rate is high (n approaches k on average), leading to significant reductions in generation time. Importantly, the statistical acceptance mechanism ensures that the final generated sequence follows the exact probability distribution of the target model.
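To get a feel for the payoff, a rough back-of-the-envelope estimate is sketched below. It assumes each drafted token is accepted independently with a fixed probability alpha, which real acceptance rates only approximate:

def expected_tokens_per_target_pass(alpha, k):
    # The number of accepted drafts follows a geometric distribution capped
    # at k, and every cycle emits one extra sampled token, giving
    # E[tokens per target pass] = (1 - alpha**(k + 1)) / (1 - alpha).
    if alpha >= 1.0:
        return k + 1
    return (1 - alpha ** (k + 1)) / (1 - alpha)

# For example, alpha = 0.8 and k = 5 yield roughly 3.7 tokens per
# expensive target-model forward pass instead of 1.
print(expected_tokens_per_target_pass(0.8, 5))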
Flowchart illustrating the speculative decoding process. The draft model proposes tokens, the target model verifies them in a single pass, and an acceptance loop determines how many proposed tokens are kept before sampling the next token.
Here's a PyTorch-like snippet illustrating the core loop structure:
import torch
import torch.nn.functional as F


def speculative_decode_step(target_model, draft_model, input_ids, k):
    """
    Performs one step of speculative decoding.
    Assumes models return logits and handle KV caching internally.
    This is a simplified illustration.
    """
    input_len = input_ids.shape[-1]

    # 1. Draft generation
    # (using the draft model's autoregressive generation; sampling
    # settings are omitted for brevity)
    draft_output_ids = draft_model.generate(input_ids, max_new_tokens=k)
    # Keep only the k newly drafted tokens
    draft_ids = draft_output_ids[:, input_len:]
    # Combine the original input with the draft tokens for verification
    verify_ids = torch.cat([input_ids, draft_ids], dim=-1)

    # 2. Target verification (single forward pass)
    # target_logits shape: [batch_size, verify_seq_len, vocab_size]
    with torch.no_grad():  # Ensure no gradients are computed
        target_logits = target_model(verify_ids).logits
        # Draft probabilities for the drafted positions. Recomputed here
        # for clarity; a real implementation would reuse the probabilities
        # already produced during draft generation.
        draft_logits = draft_model(verify_ids).logits

    # Logits at position t predict the token at position t + 1, so the
    # logits predicting draft_ids[:, j] sit at position input_len - 1 + j.
    # Shapes: [batch_size, k, vocab_size]
    target_probs = F.softmax(target_logits[:, input_len - 1:-1, :], dim=-1)
    draft_probs = F.softmax(draft_logits[:, input_len - 1:-1, :], dim=-1)

    # 3. Acceptance loop: accept draft token j with probability
    # min(1, p_target / p_draft), stopping at the first rejection
    accepted_count = 0
    next_token = None
    for j in range(k):
        # Probabilities of the specific token that *was* drafted at step j
        # Shape: [batch_size, 1]
        token_j = draft_ids[:, j:j + 1]
        p_target = target_probs[:, j, :].gather(-1, token_j)
        p_draft = draft_probs[:, j, :].gather(-1, token_j)

        # Add epsilon for numerical stability
        ratio = p_target / (p_draft + 1e-8)
        random_uniform = torch.rand_like(ratio)

        # Simplification: accept only if every sequence in the batch accepts
        if (random_uniform <= ratio).all():
            accepted_count += 1
        else:
            # Rejection occurred: sample the replacement token from the
            # corrected residual distribution norm(max(0, p - q)), which
            # preserves the target model's output distribution. (This
            # batch-level shortcut is exact only for batch_size 1.)
            p_modified = (
                target_probs[:, j, :] - draft_probs[:, j, :]
            ).clamp(min=0)
            p_modified = p_modified / p_modified.sum(dim=-1, keepdim=True)
            next_token = torch.multinomial(p_modified, num_samples=1)
            break

    if accepted_count == k:
        # All k drafts accepted: sample the (k+1)-th token from the target
        # model's distribution at the last position
        next_token_probs = F.softmax(target_logits[:, -1, :], dim=-1)
        next_token = torch.multinomial(next_token_probs, num_samples=1)

    # Return the input extended by the accepted prefix plus one new token
    final_ids = torch.cat(
        [input_ids, draft_ids[:, :accepted_count], next_token], dim=-1
    )
    return final_ids


# Example usage
# current_tokens = ...  # Initial sequence, shape [batch_size, seq_len]
# new_tokens = speculative_decode_step(
#     large_model, small_model, current_tokens, k=5
# )
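A full generation loop would call such a step function repeatedly. The sketch below assumes the speculative_decode_step function above and a simple length-based stopping condition, both illustrative choices rather than a fixed API:

def speculative_generate(target_model, draft_model, input_ids,
                         max_new_tokens, k=5):
    # Run speculative steps until enough new tokens have been produced,
    # then trim any surplus from the final step.
    start_len = input_ids.shape[-1]
    while input_ids.shape[-1] - start_len < max_new_tokens:
        input_ids = speculative_decode_step(
            target_model, draft_model, input_ids, k
        )
    return input_ids[:, :start_len + max_new_tokens]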
Speculative decoding represents a promising direction for accelerating LLM inference, particularly valuable in latency-sensitive applications. While it introduces additional complexity compared to standard autoregressive decoding, the potential for substantial speedups often justifies the effort, especially when combined with other optimization techniques like KV caching and optimized attention kernels.