Handling concurrent requests efficiently is a significant challenge when serving large language models (LLMs). Unlike typical stateless web services, where requests are independent and quick to process, LLM inference has unique characteristics:
- Generation is autoregressive: each request produces its output one token at a time, requiring many sequential forward passes.
- Output lengths vary widely and are not known in advance, so requests finish at different times.
- Each in-flight request holds a large KV cache in GPU memory for the duration of its generation.
- A single request's forward pass rarely saturates the GPU's compute capacity on its own.
Simply processing requests sequentially leads to extremely low throughput and poor hardware utilization, as the expensive GPU sits idle most of the time. Launching a separate model instance for each concurrent request is prohibitively expensive due to the massive memory requirements. Therefore, batching techniques are essential for optimizing throughput and cost.
The most straightforward approach is static batching. Incoming requests are collected until a predefined batch size is reached, or a timeout occurs. The server then pads all sequences in the batch to the length of the longest sequence and processes the entire batch in a single forward pass.
Pros:
- Simple to implement and reason about.
- A single padded forward pass keeps the GPU busier than purely sequential processing.
Cons:
- Padding shorter sequences to the longest one in the batch wastes compute and memory.
- Every request waits for the longest (slowest) request in its batch to finish, hurting latency for short requests.
- Requests may sit in the queue until the batch fills or the timeout fires before processing even starts.
Consider a batch with two sequences: A (10 tokens) and B (100 tokens). Using static batching, sequence A will be padded to 100 tokens. The GPU computes outputs for all 100 positions for both sequences, even though 90 positions for sequence A are just padding.
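To see how much compute padding can waste, here is a small illustrative sketch using a Hugging Face tokenizer; the "gpt2" checkpoint is just a convenient placeholder, and the prompts are made up. It pads a two-sequence batch to the longest sequence and reports what fraction of the processed positions are padding.
# Static Batching Snippet (Illustrative - Padding Waste)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 defines no pad token

prompts = [
    "Short prompt.",                                           # sequence A (short)
    "A much longer prompt " + "with many more tokens " * 20,   # sequence B (long)
]

# padding=True pads every sequence to the longest one in the batch
batch = tokenizer(prompts, return_tensors="pt", padding=True)

total_positions = batch["input_ids"].numel()
real_tokens = int(batch["attention_mask"].sum())
print(f"Batch shape: {tuple(batch['input_ids'].shape)}")
print(f"Padded (wasted) positions: {1 - real_tokens / total_positions:.1%}")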
Dynamic batching offers an improvement by forming batches more flexibly. Instead of waiting for a fixed batch size, the server collects requests arriving within a short time window (e.g., 10 milliseconds) and batches them together dynamically, up to a maximum batch size limit.
Requests arriving within a defined time window are grouped into batches for processing.
Pros:
- Lower queuing latency than waiting for a fixed batch size, especially under light or bursty traffic.
- Batch size adapts to the current request rate.
Cons:
- Batches are still processed as a unit: padding waste remains, and short requests are still held until the longest request in their batch completes.
- Requests arriving while a batch is running must wait for the next one.
- Choosing the time window is a trade-off between latency and batch efficiency.
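To make the time-window idea concrete, here is a minimal sketch of a request collector using only the standard library; the names and limits (collect_batch, MAX_BATCH_SIZE, BATCH_WINDOW_S) are illustrative and not tied to any particular framework.
# Dynamic Batching Snippet (Illustrative - Time-Window Collector)
import queue
import time

request_queue = queue.Queue()
MAX_BATCH_SIZE = 8
BATCH_WINDOW_S = 0.010  # 10 ms collection window

def collect_batch():
    """Block until one request arrives, then gather more requests for up to
    BATCH_WINDOW_S or until MAX_BATCH_SIZE is reached."""
    batch = [request_queue.get()]  # wait for the first request
    deadline = time.time() + BATCH_WINDOW_S
    while len(batch) < MAX_BATCH_SIZE:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break
    return batch  # hand off to the model for one padded forward pass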
Continuous batching (also known as in-flight batching or iteration-level batching) is a more advanced technique implemented in modern LLM serving frameworks like vLLM, Text Generation Inference (TGI), and NVIDIA Triton's TensorRT-LLM backend. It fundamentally changes how batching is performed during autoregressive decoding.
Instead of batching entire requests, continuous batching operates at the level of individual generation steps (iterations). The core idea is:
- At every decoding iteration, the scheduler assembles a batch from all currently active sequences and computes just their next token.
- Sequences that emit an end-of-sequence token (or hit a length limit) leave the batch immediately, freeing their slot and memory.
- Newly arrived requests join at the next iteration instead of waiting for an entire previous batch to finish.
Continuous batching processes the next token generation for all active sequences together in each iteration, allowing new requests to join efficiently.
Pros:
- Much higher GPU utilization and throughput: slots freed by finished sequences are reused immediately.
- No head-of-line blocking: short responses return as soon as they finish, regardless of longer requests in the same batch.
- New requests begin generating within roughly one iteration rather than waiting for a batch boundary.
Cons:
- Considerably more complex scheduling and state management than static or dynamic batching.
- Requires careful KV cache memory management, since the set of active sequences changes every iteration.
Continuous batching's effectiveness relies heavily on efficient KV cache management. Storing the entire KV cache for potentially hundreds of concurrent sequences contiguously is often infeasible due to memory fragmentation.
Techniques like PagedAttention, pioneered by vLLM, address this. Inspired by virtual memory and paging in operating systems, PagedAttention allocates the KV cache in non-contiguous, fixed-size blocks (pages). This allows:
- Blocks to be allocated on demand as a sequence grows, instead of reserving space for the maximum possible length up front.
- Near-elimination of memory fragmentation, so many more sequences fit in the same GPU memory.
- Blocks to be shared between sequences, for example when requests share a common prompt prefix.
While a detailed implementation is complex, understanding the concept highlights how memory management is intertwined with achieving high concurrency in LLM serving.
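As a conceptual sketch only: the toy allocator below shows the bookkeeping idea behind paged KV caches, namely a pool of fixed-size blocks, a free list, and a per-sequence block table. The class, its methods, and BLOCK_SIZE are invented for illustration; this is not vLLM's actual API, which manages per-layer key/value tensors on the GPU.
# PagedAttention Snippet (Illustrative - Block Table Bookkeeping Only)
BLOCK_SIZE = 16  # tokens stored per KV cache block

class PagedKVCacheManager:
    def __init__(self, num_blocks):
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}  # seq_id -> list of physical block ids

    def append_token(self, seq_id, num_tokens_so_far):
        """Ensure the sequence has a block with room for one more token."""
        table = self.block_tables.setdefault(seq_id, [])
        if num_tokens_so_far % BLOCK_SIZE == 0:  # current blocks are full
            if not self.free_blocks:
                raise MemoryError("No free KV blocks (would trigger preemption)")
            table.append(self.free_blocks.pop())
        return table  # attention kernels would gather K/V from these blocks

    def free_sequence(self, seq_id):
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))

# Usage: a sequence generating 40 tokens needs ceil(40 / 16) = 3 blocks
manager = PagedKVCacheManager(num_blocks=256)
for i in range(40):
    manager.append_token("req-1", i)
print(len(manager.block_tables["req-1"]))  # -> 3
manager.free_sequence("req-1")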
Implementing these advanced batching strategies from scratch is complex. Thankfully, specialized serving frameworks handle much of this complexity, including:
- vLLM, which combines continuous batching with PagedAttention.
- Hugging Face Text Generation Inference (TGI).
- NVIDIA TensorRT-LLM, typically served through the Triton Inference Server.
When configuring these systems, you'll typically encounter parameters like:
- max_num_batched_tokens: The maximum total number of tokens (sum across all sequences) that can be processed in a single iteration batch. This helps control GPU memory usage.
- max_num_seqs: The maximum number of concurrent sequences the server can handle.
- max_seq_len: The maximum sequence length supported.
Here's a highly simplified Python snippet illustrating the core loop logic for managing requests, without detailing KV cache paging or framework specifics:
# Python Server Snippet (Illustrative - Continuous Batching Idea)
import torch
import queue
import threading
import time
import uuid
# Assume 'model' is loaded LLM supporting a step-by-step forward method
# Assume 'tokenizer' is available
request_queue = queue.Queue()
# active_requests stores state:
# {req_id: {'prompt_tokens', 'output_tokens', 'kv_cache', 'is_finished'}}
active_requests = {}
# result_store stores finished outputs: {req_id: 'full_output_string'}
result_store = {}
MAX_CONCURRENT_REQUESTS = 16 # Limit based on memory/compute
SCHEDULING_INTERVAL = 0.005 # How often the scheduler runs (5ms)
def get_next_batch():
    """Gather requests ready for the next inference step."""
    batch_req_ids = []
    batch_input_tokens = []
    batch_kv_caches = []

    # Prioritize existing active requests
    for req_id, state in list(active_requests.items()):  # Iterate over a copy
        if not state['is_finished']:
            batch_req_ids.append(req_id)
            # Determine the token to feed for the next step
            if state['output_tokens'] is None:  # First step after prompt
                input_token = state['prompt_tokens'][:, -1:]
            else:  # Subsequent steps
                input_token = state['output_tokens'][:, -1:]
            batch_input_tokens.append(input_token)
            batch_kv_caches.append(state['kv_cache'])  # Might be None initially

    # Add new requests if capacity allows
    available_slots = MAX_CONCURRENT_REQUESTS - len(batch_req_ids)
    for _ in range(available_slots):
        try:
            new_req_id, prompt_str = request_queue.get_nowait()
            prompt_tokens = tokenizer.encode(
                prompt_str, return_tensors="pt"
            ).to(model.device)
            active_requests[new_req_id] = {
                'prompt_tokens': prompt_tokens,
                'output_tokens': None,  # Initialize output
                'kv_cache': None,       # Initialize KV cache
                'is_finished': False
            }
            # Add to the current batch for its first step
            # (simplified: a real server would prefill the full prompt)
            batch_req_ids.append(new_req_id)
            batch_input_tokens.append(prompt_tokens[:, -1:])  # Use last prompt token
            batch_kv_caches.append(None)
            print(f"Scheduler: Added new request {new_req_id}")
        except queue.Empty:
            break  # No more new requests waiting

    return (
        batch_req_ids,
        batch_input_tokens,
        batch_kv_caches
    )
def inference_scheduler_loop():
    """Main loop to schedule and run inference batches."""
    while True:
        start_time = time.time()

        # 1. Get requests for the next step
        (
            batch_req_ids,
            batch_input_tokens,
            batch_kv_caches
        ) = get_next_batch()

        if not batch_req_ids:
            time.sleep(SCHEDULING_INTERVAL)  # Wait if nothing to process
            continue

        # 2. Prepare the batch for the model
        # (simplified: needs proper collation/padding if the model requires it)
        # Assuming the model accepts a list of tensors and KV caches;
        # a simple concat won't work for varying lengths without padding:
        # input_ids = torch.cat(batch_input_tokens, dim=0)

        # 3. Run one inference step (highly simplified)
        # This is where the actual model forward pass happens.
        # It takes the current input tokens and past KV states, and
        # returns logits for the next token plus the updated KV states.
        # === Mock implementation ===
        logits = torch.randn(
            len(batch_req_ids), 1, tokenizer.vocab_size
        ).to(model.device)
        next_token_ids = torch.argmax(logits, dim=-1)  # Shape: [batch_size, 1]
        updated_kv_caches = [
            f"mock_kv_cache_for_{req_id}" for req_id in batch_req_ids
        ]  # Placeholder: a real KV cache holds per-layer key/value tensors
        # === End mock ===
        print(f"Scheduler: Processed batch of size {len(batch_req_ids)}")

        # 4. Update request states
        finished_ids = []
        for i, req_id in enumerate(batch_req_ids):
            state = active_requests[req_id]
            current_next_token = next_token_ids[i:i+1]  # Keep shape [1, 1]

            if state['output_tokens'] is None:
                state['output_tokens'] = current_next_token
            else:
                state['output_tokens'] = torch.cat(
                    [state['output_tokens'], current_next_token],
                    dim=1
                )
            state['kv_cache'] = updated_kv_caches[i]  # Store updated cache

            # Check termination (EOS token or max length)
            # Simplified: assume the EOS token ID is 2
            if (current_next_token.item() == 2 or
                    state['output_tokens'].shape[1] > 100):
                state['is_finished'] = True
                full_sequence = torch.cat(
                    [state['prompt_tokens'], state['output_tokens']],
                    dim=1
                )
                result_store[req_id] = tokenizer.decode(
                    full_sequence[0], skip_special_tokens=True
                )
                finished_ids.append(req_id)
                print(f"Scheduler: Finished request {req_id}")

        # 5. Clean up finished requests from the active pool
        for req_id in finished_ids:
            if req_id in active_requests:
                del active_requests[req_id]

        # Ensure the loop runs roughly at the desired interval
        elapsed_time = time.time() - start_time
        sleep_time = max(0, SCHEDULING_INTERVAL - elapsed_time)
        time.sleep(sleep_time)
# --- Example Usage ---
# Start scheduler in a background thread
# scheduler_thread = threading.Thread(
# target=inference_scheduler_loop, daemon=True
# )
# scheduler_thread.start()
# Simulate incoming requests
# req_id_1 = str(uuid.uuid4())
# request_queue.put((
# req_id_1,
# "Explain the theory of relativity in simple terms:"
# ))
# req_id_2 = str(uuid.uuid4())
# request_queue.put((
# req_id_2,
# "Write a short poem about autumn leaves:"
# ))
# Later, check for results
# if req_id_1 in result_store:
# print("Result 1:", result_store[req_id_1])
In summary, handling concurrent requests for LLM serving necessitates moving past simple static or dynamic batching. Continuous batching, combined with sophisticated memory management like PagedAttention, provides the highest throughput and best resource utilization by processing generation steps iteratively across all active requests. Using specialized serving frameworks is generally the most practical way to implement these advanced techniques in production.