Standard batch processing, a foundation of modern GPU-accelerated deep learning, relies on performing the same operation uniformly across a contiguous block of input data. This assumption breaks down with Mixture of Experts models. During inference, the gating network routes tokens within a single batch to different experts. This sparse activation pattern means that a contiguous input batch is scattered across multiple, independent computational paths, rendering a single, large matrix multiplication impossible.

This behavior, often called the "gather-scatter" problem, is the primary source of computational inefficiency in MoE inference. If not handled correctly, it leads to a series of small, inefficient matrix operations that underutilize the GPU's parallel processing capabilities and dramatically increase latency. The solution is not to abandon batching, but to adapt it to the sparse nature of the workload through dynamic token grouping and dispatching.

## The Token Dispatching Workflow

The core strategy is to transform the token-to-expert assignment from a sparse, irregular memory access pattern into a series of dense, regular computations. This is achieved by sorting and regrouping the tokens in a batch according to their designated expert after the gating network has made its routing decisions.

The process unfolds in a few distinct steps:

1. Gating and Routing: For an incoming batch of tokens, the gating network computes routing logits and determines the target expert(s) for each token.
2. Grouping and Permutation: Instead of immediately sending tokens to experts, the system first determines the expert assignments for the entire batch. It then creates a permutation map to reorder the tokens, grouping all tokens assigned to Expert 1 together, all for Expert 2, and so on. This creates dense "micro-batches," one for each active expert (a short PyTorch sketch of this step follows the diagram below).
3. Dispatch and Compute: Each dense micro-batch is dispatched to its corresponding expert. Since the expert now receives a contiguous block of tokens, it can perform its computation (e.g., the feed-forward network) using a single, efficient matrix multiplication.
4. Unsorting: The output representations from each expert must be returned to their original positions in the token sequence. The inverse of the initial permutation map is used to scatter the results back into a new tensor that preserves the original token order.

The diagram below illustrates this workflow. An initial batch of tokens is routed, sorted into expert-specific groups, processed, and then reassembled.

```dot
digraph G {
  rankdir=TB;
  splines=ortho;
  node [shape=box, style="rounded,filled", fontname="Arial", fillcolor="#e9ecef"];
  edge [fontname="Arial"];

  subgraph cluster_input {
    label="1. Input Batch"; style=filled; color="#f8f9fa";
    Input [label="Batch of Tokens\n(T1, T2, T3, T4, T5, T6)", fillcolor="#a5d8ff"];
  }
  subgraph cluster_routing {
    label="2. Gating Network"; style=filled; color="#f8f9fa";
    Router [label="Router", shape=diamond, fillcolor="#ffc078"];
  }
  subgraph cluster_dispatch {
    label="3. Token Grouping & Dispatch"; style=filled; color="#f8f9fa";
    node [style=filled];
    GroupedE1 [label="Micro-batch for E1\n(T2, T5)", fillcolor="#b2f2bb"];
    GroupedE2 [label="Micro-batch for E2\n(T1, T4, T6)", fillcolor="#b2f2bb"];
    GroupedE3 [label="Micro-batch for E3\n(T3)", fillcolor="#b2f2bb"];
  }
  subgraph cluster_experts {
    label="4. Expert Computation"; style=filled; color="#f8f9fa";
    node [style=filled, shape=ellipse];
    Expert1 [label="Expert 1", fillcolor="#69db7c"];
    Expert2 [label="Expert 2", fillcolor="#69db7c"];
    Expert3 [label="Expert 3", fillcolor="#69db7c"];
    Expert4 [label="Expert 4 (Inactive)", fillcolor="#dee2e6"];
  }
  subgraph cluster_output {
    label="5. Unsorting & Output"; style=filled; color="#f8f9fa";
    Output [label="Output Batch\n(in original order)", fillcolor="#a5d8ff"];
  }

  Input -> Router [label="Routing assignments calculated"];
  Router -> GroupedE1 [style=dashed, arrowhead=none];
  Router -> GroupedE2 [style=dashed, arrowhead=none];
  Router -> GroupedE3 [style=dashed, arrowhead=none];
  GroupedE1 -> Expert1 [label="Dispatch"];
  GroupedE2 -> Expert2 [label="Dispatch"];
  GroupedE3 -> Expert3 [label="Dispatch"];
  Expert1 -> Output [label="Gather & Unsort", style=dashed];
  Expert2 -> Output [style=dashed];
  Expert3 -> Output [style=dashed];

  {rank=same; GroupedE1; GroupedE2; GroupedE3;}
  {rank=same; Expert1; Expert2; Expert3; Expert4;}
}
```

Token dispatching workflow. Tokens are routed, permuted into expert-specific groups, computed densely, and then un-permuted back to their original sequence order.
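To make the grouping and permutation step concrete, here is a minimal PyTorch sketch. The function `permute_by_expert` and its signature are illustrative assumptions, not part of any particular library; it produces the expert-contiguous layout described above along with the inverse permutation needed for the unsorting step.

```python
import torch

def permute_by_expert(tokens, expert_indices, num_experts):
    """Sort a batch so that tokens routed to the same expert are contiguous.

    Returns the permuted tokens, the number of tokens per expert (used to
    slice out each dense micro-batch), and the inverse permutation needed to
    restore the original token order after the experts have run.
    """
    # A stable sort by expert id groups tokens without reordering within a group.
    perm = torch.argsort(expert_indices, stable=True)
    tokens_sorted = tokens[perm]                      # dense, expert-contiguous layout
    counts = torch.bincount(expert_indices, minlength=num_experts)
    inverse_perm = torch.argsort(perm)                # undoes the permutation
    return tokens_sorted, counts, inverse_perm

# After each expert has processed its contiguous slice of `tokens_sorted`,
# indexing the concatenated expert outputs with `inverse_perm` scatters the
# results back into the original sequence order:
#   outputs_in_original_order = expert_outputs_sorted[inverse_perm]
```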
## Handling Imbalanced Loads and Capacity

A significant challenge in this process is load imbalance. The routing decisions are dynamic, so for any given batch, some experts may be assigned many tokens while others receive few or none. This creates micro-batches of varying sizes.

```json
{"data":[{"x":["E1","E2","E3","E4","E5","E6","E7","E8"],"y":[42,98,15,0,112,5,71,0],"type":"bar","marker":{"color":["#40c057","#228be6","#f76707","#adb5bd","#228be6","#f76707","#40c057","#adb5bd"]},"name":"Tokens per Expert"}],"layout":{"title":{"text":"Token Distribution Across Experts for a Single Batch"},"xaxis":{"title":"Expert ID"},"yaxis":{"title":"Number of Tokens Assigned"},"bargap":0.2}}
```

An example of an imbalanced token load across eight experts for a single inference batch. Experts 4 and 8 are inactive, while Expert 5 is heavily loaded.

To manage this imbalance and maintain a regular computational structure, systems often employ padding. The micro-batches for all experts are padded to a uniform size, typically determined by the size of the largest micro-batch in the global batch. While this introduces some wasted computation on the padded elements, it simplifies the execution graph and often yields higher overall throughput than handling many differently-sized operations.

This is also where the `capacity_factor` from training becomes relevant again. During inference, if the number of tokens routed to an expert exceeds its defined capacity (`batch_size / num_experts * capacity_factor`), the excess tokens are typically "dropped." Their representations pass through the MoE layer unchanged, equivalent to being processed by a residual connection. This is a direct trade-off: a lower capacity saves memory and computation but risks a quality degradation if too many tokens are dropped. For production systems, this value must be tuned based on observed token distributions and latency requirements.
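As a concrete sketch of this capacity logic, the following hypothetical PyTorch helper pads each expert's micro-batch to a fixed capacity and drops overflow tokens. The function name and bookkeeping tensors are illustrative, not taken from any specific serving framework.

```python
import torch

def build_expert_batches(tokens, router_indices, num_experts, capacity_factor=1.25):
    """Pad each expert's micro-batch to a fixed capacity, dropping overflow.

    Tokens beyond an expert's capacity are dropped: the MoE layer simply leaves
    their representations unchanged (the residual path carries them forward).
    Returns the padded per-expert batches plus a map back to original positions.
    """
    num_tokens, hidden_dim = tokens.shape
    capacity = int(num_tokens / num_experts * capacity_factor)

    expert_batches = torch.zeros(num_experts, capacity, hidden_dim,
                                 dtype=tokens.dtype, device=tokens.device)
    # token_map[e, slot] = original position of the token in that slot (-1 = padding)
    token_map = torch.full((num_experts, capacity), -1,
                           dtype=torch.long, device=tokens.device)

    for e in range(num_experts):
        positions = torch.nonzero(router_indices == e, as_tuple=False).squeeze(-1)
        kept = positions[:capacity]                        # overflow tokens are dropped
        expert_batches[e, :kept.numel()] = tokens[kept]    # remaining slots stay zero-padded
        token_map[e, :kept.numel()] = kept
    return expert_batches, token_map
```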
## Advanced Scheduling with Custom Kernels

The gather-compute-scatter workflow, while effective, introduces overhead from data movement and multiple kernel launches on the GPU. The permutation and inverse permutation steps require reading and writing the entire batch of token data, which can be a bottleneck.

For highly optimized inference servers, these operations can be fused into a single, custom GPU kernel using frameworks like Triton or by writing CUDA directly. Instead of separate steps, a fused kernel can:

1. Read a token from the original, unsorted batch.
2. Identify its target expert.
3. Load the weights for that specific expert.
4. Perform the computation.
5. Write the result directly to the correct location in the final output tensor.

This approach minimizes data movement between GPU memory and its compute cores, significantly reducing the overhead of the dispatch logic.

Here is a simplified Triton-style representation of what a fused kernel accomplishes. The indexing is schematic; a production kernel operates on tiles of tokens and uses `tl.load`/`tl.store` with explicit pointer arithmetic:

```python
import triton
import triton.language as tl

# Simplified sketch of a fused token-dispatch kernel (one token per program instance)
@triton.jit
def fused_moe_kernel(tokens_in, tokens_out, router_indices, expert_weights):
    # Get the unique ID for this instance of the kernel
    token_id = tl.program_id(0)

    # 1. Read the token's assigned expert index
    expert_idx = router_indices[token_id]

    # 2. Load the input token data
    input_data = tokens_in[token_id, :]

    # 3. Load the corresponding expert weights
    #    (a simplification; in practice this is the complex part)
    w1 = expert_weights[expert_idx, 0, :, :]
    w2 = expert_weights[expert_idx, 1, :, :]

    # 4. Perform the expert computation
    hidden = tl.dot(input_data, w1)
    hidden = tl.maximum(hidden, 0.0)  # ReLU
    output_data = tl.dot(hidden, w2)

    # 5. Write the result to the correct output position
    tokens_out[token_id, :] = output_data
```

In PyTorch, the same dispatch logic can be written without kernel fusion. The snippet below illustrates the standard (and less efficient) gather-scatter approach:

```python
import torch

def pytorch_moe_dispatch(tokens_in, router_indices, experts_list):
    """Simulates MoE dispatch without custom kernel fusion.

    Args:
        tokens_in (torch.Tensor): Input tokens of shape (num_tokens, hidden_dim).
        router_indices (torch.Tensor): Expert assignment per token, shape (num_tokens,).
        experts_list (list of torch.nn.Module): One feed-forward expert per index.

    Returns:
        torch.Tensor: Output tokens in their original order, shape (num_tokens, hidden_dim).
    """
    num_experts = len(experts_list)

    # Per-expert output buffers, aligned with the original token positions
    expert_outputs = [torch.zeros_like(tokens_in) for _ in range(num_experts)]
    # Boolean mask of the tokens routed to each expert
    expert_masks = [router_indices == i for i in range(num_experts)]

    # 1. Gather and dispatch to experts
    for i in range(num_experts):
        mask = expert_masks[i]
        # Select tokens assigned to the current expert (gather).
        # This materializes a new tensor per expert and can be inefficient.
        tokens_for_expert = tokens_in[mask]
        if tokens_for_expert.numel() > 0:  # only process if there are tokens
            # Expert computation (e.g., a feed-forward network)
            processed_tokens = experts_list[i](tokens_for_expert)
            # Scatter the results back to their original positions via the mask
            expert_outputs[i][mask] = processed_tokens

    # 2. Combine outputs from all experts and unsort.
    # Each expert's output already sits at the correct original indices thanks to
    # the boolean masks, so summing the per-expert tensors restores token order.
    # A more explicit scheme would instead index with an inverse permutation map.
    final_output = torch.sum(torch.stack(expert_outputs), dim=0)
    return final_output


# Example usage (illustrative; the expert modules are simple dummies)
if __name__ == '__main__':
    batch_size = 6
    hidden_dim = 128
    num_experts = 4

    # Dummy input tokens and router assignments (e.g., from a gating network)
    dummy_tokens_in = torch.randn(batch_size, hidden_dim)
    dummy_router_indices = torch.randint(0, num_experts, (batch_size,))

    # Dummy expert modules (two-layer feed-forward networks for demonstration)
    class DummyExpert(torch.nn.Module):
        def __init__(self, hidden_dim):
            super().__init__()
            self.linear1 = torch.nn.Linear(hidden_dim, hidden_dim * 2)
            self.relu = torch.nn.ReLU()
            self.linear2 = torch.nn.Linear(hidden_dim * 2, hidden_dim)

        def forward(self, x):
            return self.linear2(self.relu(self.linear1(x)))

    dummy_experts = [DummyExpert(hidden_dim) for _ in range(num_experts)]

    print("PyTorch MoE dispatch simulation (gather-scatter concept):")
    print(f"Input tokens shape: {dummy_tokens_in.shape}")
    print(f"Router indices: {dummy_router_indices.tolist()}")

    output_tokens = pytorch_moe_dispatch(dummy_tokens_in, dummy_router_indices, dummy_experts)
    print(f"Output tokens shape: {output_tokens.shape}")
```

This version relies on explicit boolean indexing and a Python loop over experts, which is markedly less efficient than the fused-kernel approach. Highly optimized PyTorch MoE implementations therefore rely on custom CUDA extensions or specialized libraries (e.g., fairseq's MoE layers or Megatron-LM) to achieve fusion-like performance.

While implementing custom kernels requires specialized expertise, it represents the state of the art for minimizing MoE inference latency. For most applications, leveraging optimized libraries like vLLM or DeepSpeed-Inference, which have these techniques built in, provides a practical path to high-performance serving without manual kernel development. Ultimately, an effective batching strategy is not an optional enhancement but a fundamental requirement for deploying MoE models in production.