While optimizing individual token generation with techniques like KV caching and specialized attention mechanisms significantly reduces latency and memory overhead per step, maximizing the overall throughput of an LLM inference server requires processing multiple requests concurrently. GPUs, the workhorses of deep learning, achieve peak performance when performing large matrix multiplications. Processing requests one by one often leaves the GPU underutilized, as the control logic and data movement between steps can dominate the computation time for a single sequence. Batching strategies address this by grouping multiple inference requests together, allowing the GPU to process them in parallel, thus amortizing the overheads and increasing the number of tokens generated per second.
The simplest approach is static batching. Here, the inference server waits until a predefined number of requests (batch_size) arrive or a timeout occurs. These requests are then grouped, padded to the length of the longest sequence in the batch, and processed together.
Padding involves adding special tokens to the shorter sequences in the batch so that all sequences have the same length. This creates a uniform tensor shape required for efficient matrix multiplication on the GPU. An attention mask is crucial here to ensure that the model does not attend to these padding tokens during the self-attention calculations.
import torch
# Example: Preparing a static batch
requests = [
"This is sequence one.",
"A shorter sequence.",
"This is the third sequence, and it's quite long."
]
# Assume tokenizer adds special tokens and converts to IDs
# Tokenized sequences (simplified IDs)
seq1_ids = [101, 2023, 2003, 5537, 2028, 1012, 102] # len=7
seq2_ids = [101, 1037, 9087, 5537, 1012, 102] # len=6
seq3_ids = [
    101, 2023, 2003, 1996, 2353, 5537, 1010, 1998, 2009, 1005, 1055,
    3747, 2146, 1012, 102
]  # len=15
sequences = [seq1_ids, seq2_ids, seq3_ids]
max_len = max(len(seq) for seq in sequences) # max_len = 15
padding_token_id = 0 # Assuming PAD token ID is 0
padded_sequences = []
attention_masks = []
for seq in sequences:
    pad_len = max_len - len(seq)
    padded_seq = seq + [padding_token_id] * pad_len
    # Mask out padding: 1 for real tokens, 0 for padded positions
    attention_mask = [1] * len(seq) + [0] * pad_len
    padded_sequences.append(padded_seq)
    attention_masks.append(attention_mask)
# Convert to tensors for model input
input_ids = torch.tensor(padded_sequences)
attention_mask = torch.tensor(attention_masks)
print("Input IDs Shape:", input_ids.shape) # torch.Size([3, 15])
print("Attention Mask Shape:", attention_mask.shape) # torch.Size([3, 15])
# Now input_ids and attention_mask can be fed to the model
# model(input_ids=input_ids,
# attention_mask=attention_mask, ...)
Limitations of Static Batching: requests that arrive early must wait for the batch to fill (or for the timeout), adding latency; padding every sequence to the length of the longest one wastes computation on tokens that carry no information; and during generation the whole batch is held until its longest output finishes, so completed sequences continue to occupy GPU slots.
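To put a rough number on the padding cost, a quick calculation on the example batch above (sequence lengths 7, 6, and 15) shows that more than a third of the token positions the GPU processes are padding:
# Rough estimate of wasted (padded) positions in the static batch above
lengths = [7, 6, 15]                                   # token counts from the example
max_len = max(lengths)                                 # 15
total_positions = max_len * len(lengths)               # 45 positions computed
padded_positions = sum(max_len - l for l in lengths)   # 8 + 9 + 0 = 17
print(f"Padding overhead: {padded_positions / total_positions:.0%}")  # ~38%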
Dynamic batching aims to improve GPU utilization over static batching by grouping requests that arrive within a small time window, rather than waiting for a fixed batch size. This often leads to batches with more varied sequence lengths. While it still requires padding, the server can potentially start processing batches more frequently.
The core idea is flexibility. Instead of a fixed batch_size, the server might accumulate requests for a short duration (e.g., 10 milliseconds) and then process whatever has arrived as a batch.
# Server logic for dynamic batching
import time
import queue

request_queue = queue.Queue()
MAX_WAIT_TIME_MS = 10
MAX_BATCH_SIZE = 16  # Optional upper limit

def process_batch(batch):
    # (Similar padding logic as static batching)
    # ... tokenize, pad, create masks ...
    # input_ids, attention_mask = prepare_batch(batch)
    # with torch.no_grad():
    #     outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # ... handle outputs ...
    print(f"Processed batch of size {len(batch)}")

while True:
    batch = []
    start_time = time.time()
    while True:
        try:
            # Non-blocking check for new requests
            request = request_queue.get_nowait()
            batch.append(request)
            if len(batch) >= MAX_BATCH_SIZE:  # If max size reached
                break
        except queue.Empty:
            # Check if wait time exceeded
            if (time.time() - start_time) * 1000 > MAX_WAIT_TIME_MS:
                break
            # Optional short sleep to avoid busy-waiting
            time.sleep(0.001)
    if batch:
        process_batch(batch)
    else:
        # No requests, maybe sleep longer
        time.sleep(0.01)
While dynamic batching is better than static, it still suffers from padding inefficiency and the issue that the entire batch progresses at the pace of the slowest (longest) sequence during autoregressive generation. If one sequence needs 500 tokens and others need only 50, the batch slot remains occupied until the 500-token generation is complete, leading to GPU idling for the completed sequences within the batch (often called "bubbles").
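A rough calculation illustrates the scale of these bubbles. Assuming a hypothetical batch of eight requests in which one sequence needs 500 generated tokens and the other seven need only 50, most of the slot-steps the batch occupies are idle:
# Idle decode steps ("bubbles") when a batch is paced by its longest sequence
tokens_needed = [500] + [50] * 7                  # hypothetical per-request output lengths
iterations = max(tokens_needed)                   # the batch runs for 500 decode steps
useful_steps = sum(tokens_needed)                 # 850 slot-steps actually generate tokens
occupied_steps = iterations * len(tokens_needed)  # 4000 slot-steps held by the batch
idle_fraction = 1 - useful_steps / occupied_steps
print(f"Idle slot-steps: {occupied_steps - useful_steps} ({idle_fraction:.0%})")  # ~79%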
Continuous batching (also known as iteration-level scheduling or in-flight batching) is a more sophisticated technique designed to maximize throughput by addressing the limitations of simpler batching methods. It decouples batching at the server level from the iteration steps of the generation process.
Instead of processing a fixed batch until all sequences are complete, continuous batching operates on a per-iteration basis: at each generation step the scheduler assembles a batch from the currently active sequences, the model runs a single forward pass that produces one token for each of them, finished sequences are removed from the pool, and newly arrived requests can be admitted before the next step.
Simplified flow of continuous batching. An iteration batch is formed from active requests, processed by the GPU, and the pool is updated with new tokens, completed sequences, and incoming requests.
The key advantage is that the GPU stays busy as long as there are any active sequences ready for generation. When one sequence finishes, its slot in the next iteration's batch can immediately be taken by another sequence from the pool or a newly arrived request. This avoids the "bubbles" of static/dynamic batching where the GPU waits for the longest sequence in a batch to finish. Systems like Orca, vLLM, TensorRT-LLM, and Text Generation Inference (TGI) implement variants of continuous batching.
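The following is a minimal sketch of such an iteration-level scheduler, not the implementation of any particular serving system; the Request class and the stubbed model_step function below are hypothetical placeholders for the real tokenizer, model, and KV cache management.
# Minimal sketch of iteration-level (continuous) batching.
# Request and model_step are hypothetical stand-ins for real serving components.
import queue
from dataclasses import dataclass, field

@dataclass
class Request:
    prompt: str
    max_new_tokens: int
    tokens: list = field(default_factory=list)

    def is_finished(self):
        return len(self.tokens) >= self.max_new_tokens

def model_step(batch):
    # Placeholder for one batched forward pass; a real server would run the model
    # once over all active sequences and return one new token id per sequence.
    return [0 for _ in batch]

MAX_ACTIVE = 32                # maximum sequences decoded per iteration
request_queue = queue.Queue()  # incoming requests from clients
active = []                    # sequences currently being generated

# Seed the queue with a few hypothetical requests of different output lengths
for n in (50, 500, 50):
    request_queue.put(Request(prompt="...", max_new_tokens=n))

while active or not request_queue.empty():
    # Admit waiting requests into any free slots before the next iteration
    while len(active) < MAX_ACTIVE and not request_queue.empty():
        active.append(request_queue.get_nowait())

    # One forward pass advances every active sequence by exactly one token
    for req, token in zip(active, model_step(active)):
        req.tokens.append(token)

    # Finished sequences release their slots immediately for waiting requests
    active = [req for req in active if not req.is_finished()]

print("All requests completed")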
The Hugging Face transformers library handles padding and attention masking automatically when you pass a list of inputs to the tokenizer with padding enabled. Dedicated inference servers (Triton, TorchServe, TGI, vLLM) offer more advanced batching capabilities, including dynamic and continuous batching, optimized for production workloads.
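As a quick illustration of this automatic handling, a single tokenizer call with padding=True returns both the padded input_ids and the matching attention_mask; the gpt2 checkpoint below is used purely as an example:
from transformers import AutoTokenizer

# gpt2 is just an example checkpoint; it has no pad token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

texts = [
    "This is sequence one.",
    "A shorter sequence.",
    "This is the third sequence, and it's quite long."
]
batch = tokenizer(texts, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)       # e.g. torch.Size([3, 12]); padded to the longest text
print(batch["attention_mask"].shape)  # same shape, with 0s marking padded positions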
In summary, batching is indispensable for efficient LLM inference. While static batching is simple, dynamic and particularly continuous batching offer substantial improvements in GPU utilization and overall throughput by processing multiple requests in parallel and dynamically managing the workload across generation steps. Choosing and tuning the right batching strategy is a significant aspect of deploying LLMs cost-effectively at scale.