Standard batch processing, a foundation of modern GPU-accelerated deep learning, relies on performing the same operation uniformly across a contiguous block of input data. This breaks down with Mixture of Experts models. During inference, the gating network routes tokens within a single batch to different experts. This sparse activation pattern means that a contiguous input batch is scattered across multiple, independent computational paths, rendering a single, large matrix multiplication impossible.
This behavior, often called the "gather-scatter" problem, is the primary source of computational inefficiency in MoE inference. If not handled correctly, it leads to a series of small, inefficient matrix operations that underutilize the GPU's parallel processing capabilities and dramatically increase latency. The solution is not to abandon batching, but to adapt it to the sparse nature of the workload through dynamic token grouping and dispatching.
The core strategy is to transform the token-to-expert assignment from a sparse, irregular memory access pattern into a series of dense, regular computations. This is achieved by sorting and regrouping the tokens in a batch according to their designated expert after the gating network has made its routing decisions.
The process unfolds in a few distinct steps:

1. Routing: the gating network scores each token and assigns it to its designated expert (or top-k experts).
2. Permutation (gather): tokens are sorted and regrouped by expert index so that all tokens bound for the same expert occupy a contiguous block.
3. Expert computation: each contiguous group is processed by its expert as a single dense matrix multiplication.
4. Inverse permutation (scatter): the expert outputs are returned to their original positions in the batch.
The diagram below illustrates this workflow. An initial batch of tokens is routed, sorted into expert-specific groups, processed, and then reassembled.
Token dispatching workflow. Tokens are routed, permuted into expert-specific groups, computed densely, and then un-permuted back to their original sequence order.
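A minimal PyTorch sketch of this permute/un-permute pattern, using illustrative shapes (8 tokens, hidden size 16, 4 experts) rather than any particular framework's API:

import torch

# Token representations and the gating network's top-1 expert assignments.
tokens = torch.randn(8, 16)
expert_ids = torch.randint(0, 4, (8,))

# Permute: sort token indices so tokens for the same expert become contiguous.
perm = torch.argsort(expert_ids)
grouped_tokens = tokens[perm]
tokens_per_expert = torch.bincount(expert_ids, minlength=4)
# ... each contiguous group is now fed to its expert as one dense matmul ...

# Un-permute: scatter the (processed) rows back to the original token order.
inverse_perm = torch.argsort(perm)
restored = grouped_tokens[inverse_perm]
assert torch.equal(restored, tokens)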
A significant challenge in this process is load imbalance. The routing decisions are dynamic, so for any given batch, some experts may be assigned many tokens while others receive few or none. This creates micro-batches of varying sizes.
An example of an imbalanced token load across eight experts for a single inference batch. Experts 4 and 8 are inactive, while Expert 5 is heavily loaded.
To manage this imbalance and maintain a regular computational structure, systems often employ padding. The micro-batches for all experts are padded to a uniform size, typically determined by the size of the largest micro-batch in the global batch. While this introduces some wasted computation on the padded elements, it simplifies the execution graph and often yields higher overall throughput than handling many differently-sized operations.
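As a rough sketch, padding each expert's micro-batch to the size of the largest one might look like this (assuming top-1 routing; shapes and names are illustrative):

import torch

num_experts, hidden_dim = 4, 16
tokens = torch.randn(8, hidden_dim)
expert_ids = torch.randint(0, num_experts, (8,))

# Count tokens per expert and pad every micro-batch to the largest count.
counts = torch.bincount(expert_ids, minlength=num_experts)
pad_to = int(counts.max())

micro_batches = []
for e in range(num_experts):
    batch_e = tokens[expert_ids == e]
    padding = torch.zeros(pad_to - batch_e.shape[0], hidden_dim)
    micro_batches.append(torch.cat([batch_e, padding], dim=0))

# Shape (num_experts, pad_to, hidden_dim): a regular block that can be run as
# one batched matrix multiplication, at the cost of wasted work on the padding.
stacked = torch.stack(micro_batches)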
This is also where the capacity_factor from training becomes relevant again. During inference, if the number of tokens routed to an expert exceeds its defined capacity, (batch_size / num_experts) * capacity_factor, the excess tokens are typically "dropped": the expert never processes them, and their representations are carried forward unchanged by the residual connection around the MoE layer. This is a direct trade-off: a lower capacity saves memory and computation but risks quality degradation if too many tokens are dropped. For production systems, this value must be tuned based on observed token distributions and latency requirements.
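A minimal sketch of that capacity check, again assuming top-1 routing and illustrative shapes:

import torch

batch_size, hidden_dim, num_experts = 8, 16, 4
capacity_factor = 1.25
expert_capacity = int(batch_size / num_experts * capacity_factor)

tokens = torch.randn(batch_size, hidden_dim)
expert_ids = torch.randint(0, num_experts, (batch_size,))

# Position of each token within its expert's queue (0, 1, 2, ... per expert).
position_in_expert = torch.zeros(batch_size, dtype=torch.long)
for e in range(num_experts):
    mask = expert_ids == e
    position_in_expert[mask] = torch.arange(int(mask.sum()))

# Tokens beyond the capacity are dropped: the expert never sees them, and the
# residual connection around the MoE layer carries them forward unchanged.
kept = position_in_expert < expert_capacity
moe_output = torch.zeros_like(tokens)
# ... experts compute outputs only for rows where `kept` is True ...
layer_output = tokens + moe_output  # dropped tokens pass through unchanged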
The gather-compute-scatter workflow, while effective, introduces overhead from data movement and multiple kernel launches on the GPU. The permutation and inverse permutation steps require reading and writing the entire batch of token data, which can be a bottleneck.
For highly optimized inference servers, these operations can be fused into a single, custom GPU kernel using frameworks like Triton or by writing CUDA directly. Instead of separate steps, a fused kernel can:

- read each token's routing decision,
- load the token and its assigned expert's weights,
- perform the expert computation, and
- write the result directly to its final output position,

all within a single kernel launch.
This approach minimizes data movement between GPU memory and its compute cores, significantly reducing the overhead of the dispatch logic.
Here is a simplified Triton sketch of what a fused kernel might accomplish. The one-program-per-token structure and flat weight layout below are illustrative, not a production implementation:
# Sketch of a fused token-dispatch kernel in Triton (a simplification: one
# program instance per token, power-of-two hidden size D, and expert weights
# assumed to be laid out contiguously as [num_experts, 2, D, D]).
import triton
import triton.language as tl

@triton.jit
def fused_moe_kernel(tokens_in, tokens_out, router_indices, expert_weights,
                     D: tl.constexpr):
    # Each kernel instance handles one token
    token_id = tl.program_id(0)
    offs = tl.arange(0, D)

    # 1. Read the token's assigned expert index
    expert_idx = tl.load(router_indices + token_id)

    # 2. Load the input token data
    input_data = tl.load(tokens_in + token_id * D + offs)

    # 3. Load the corresponding expert weights
    #    (still a simplification; in practice these loads are tiled)
    base = expert_weights + expert_idx * 2 * D * D
    w1 = tl.load(base + offs[:, None] * D + offs[None, :])
    w2 = tl.load(base + D * D + offs[:, None] * D + offs[None, :])

    # 4. Perform the expert computation: two matmuls with a ReLU in between
    hidden = tl.maximum(tl.sum(input_data[:, None] * w1, axis=0), 0.0)
    output_data = tl.sum(hidden[:, None] * w2, axis=0)

    # 5. Write the result to the correct output position
    tl.store(tokens_out + token_id * D + offs, output_data)
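In practice, a kernel along these lines would be launched with one program instance per token (or per block of tokens); production kernels additionally tile the weight loads, accumulate with block-wise tl.dot operations, and handle top-k routing and expert capacity, which is where most of the real engineering effort lies.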
For contrast, here is the standard unfused approach in PyTorch, with the gather and scatter steps written out explicitly:
import torch
# Equivalent PyTorch snippet illustrating the "gather-scatter" problem
# This is NOT a fused kernel, but rather shows the standard
# (and less efficient) way to implement the token dispatch logic
# without custom kernel fusion.
def pytorch_moe_dispatch(tokens_in, router_indices, experts_list):
    """
    Simulates MoE dispatch without custom kernel fusion, demonstrating the
    gather-scatter approach that a fused kernel avoids.

    Args:
        tokens_in (torch.Tensor): Input tokens, shape (num_tokens, hidden_dim).
        router_indices (torch.Tensor): Expert assignment per token, shape (num_tokens,).
        experts_list (list of torch.nn.Module): Expert feed-forward networks.

    Returns:
        torch.Tensor: Output tokens in their original order, shape (num_tokens, hidden_dim).
    """
    num_experts = len(experts_list)

    # One full-size buffer per expert; each expert writes only the rows of the
    # tokens routed to it, so summing the buffers later restores the full batch.
    expert_outputs = [torch.zeros_like(tokens_in) for _ in range(num_experts)]

    # Boolean mask of the tokens routed to each expert.
    expert_masks = [router_indices == i for i in range(num_experts)]

    # 1. Gather and dispatch: iterate over experts, selecting each one's tokens.
    for i in range(num_experts):
        mask = expert_masks[i]
        # Gather the tokens assigned to this expert. This produces a
        # non-contiguous, variable-sized selection and is a key inefficiency.
        tokens_for_expert = tokens_in[mask]
        if tokens_for_expert.numel() > 0:  # Only process if there are tokens
            # Run the expert's feed-forward network on its micro-batch.
            processed_tokens = experts_list[i](tokens_for_expert)
            # Scatter the results back to the original token positions.
            expert_outputs[i][mask] = processed_tokens

    # 2. Combine the per-expert buffers. Because each token was written to
    #    exactly one buffer at its original index, the sum is already in the
    #    original sequence order; no explicit un-permutation is needed here.
    final_output = torch.sum(torch.stack(expert_outputs), dim=0)
    return final_output
# Example usage with dummy experts (simple two-layer feed-forward networks).
if __name__ == '__main__':
    batch_size = 6
    hidden_dim = 128
    num_experts = 4

    # Dummy input tokens and router assignments (as produced by a gating network);
    # each token is assigned to one expert in [0, num_experts).
    dummy_tokens_in = torch.randn(batch_size, hidden_dim)
    dummy_router_indices = torch.randint(0, num_experts, (batch_size,))

    class DummyExpert(torch.nn.Module):
        def __init__(self, hidden_dim):
            super().__init__()
            self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim * 2)
            self.relu = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(hidden_dim * 2, hidden_dim)

        def forward(self, x):
            return self.linear2(self.relu(self.linear1(x)))

    dummy_experts = [DummyExpert(hidden_dim) for _ in range(num_experts)]

    print("PyTorch MoE dispatch simulation (gather-scatter concept):")
    print(f"Input tokens shape: {dummy_tokens_in.shape}")
    print(f"Router indices: {dummy_router_indices.tolist()}")

    output_tokens = pytorch_moe_dispatch(dummy_tokens_in, dummy_router_indices, dummy_experts)
    print(f"Output tokens shape: {output_tokens.shape}")

    # Note: the explicit per-expert indexing and looping above is what makes this
    # approach slower than fused kernels like the Triton example. Highly optimized
    # PyTorch MoE implementations rely on custom CUDA extensions or specialized
    # libraries (e.g., fairseq's MoE layers, Megatron-LM) to approach fused performance.
While implementing custom kernels requires specialized expertise, it represents the state-of-the-art for minimizing MoE inference latency. For most applications, leveraging optimized libraries like vLLM or DeepSpeed-Inference, which have these techniques built-in, provides a practical path to high-performance serving without requiring manual kernel development. Ultimately, an effective batching strategy is not an optional enhancement but a fundamental requirement for deploying MoE models in production.