How To Implement Mixture of Experts (MoE) in PyTorch

By Wei Ming T. on Apr 18, 2025

Mixture of Experts (MoE) is a technique for scaling large neural networks efficiently. Instead of activating the entire model for every input, an MoE model engages only a small subset of its sub-networks, known as 'experts'. This allows for models with vastly increased parameter counts while keeping the computational cost per token manageable.

Successfully implementing MoE requires understanding its core components and training dynamics. We will walk through building an MoE layer from scratch using PyTorch, covering the gating mechanism, expert networks, and crucial training considerations like load balancing.

What is Mixture of Experts (MoE)?

An MoE layer replaces a standard dense feed-forward network (FFN) block within a larger model, such as a Transformer. It consists of two primary components:

  1. Expert Networks: A set of smaller, independent feed-forward networks. These are the 'experts' specialized in processing different aspects of the input data.
  2. Gating Network: A router network that determines which expert(s) should process each part of the input (e.g., each token in a sequence).

For a given input token, the gating network outputs probabilities or scores indicating the suitability of each expert. Typically, only the top K experts (where K is much smaller than the total number of experts) are selected and activated for that token. The final output of the MoE layer is a weighted combination of the outputs from the activated experts, using the gating network's scores as weights.
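The weighted combination can be written compactly. The notation here is introduced only for illustration: $g_i(x)$ is the gating weight of expert $i$ for token $x$, $\mathrm{Expert}_i(x)$ is that expert's output, and $\mathrm{TopK}(x)$ is the set of selected experts. The normalization on the right assumes the weights are renormalized over the selected experts, as the gate in this article does.

```latex
y(x) = \sum_{i \in \mathrm{TopK}(x)} g_i(x)\,\mathrm{Expert}_i(x),
\qquad
\sum_{i \in \mathrm{TopK}(x)} g_i(x) = 1
```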

High-level structure of a Mixture of Experts layer. Input is routed to selected experts via the gating network, and outputs are combined.

This sparse activation pattern is key to MoE's efficiency. While the total number of parameters can be very large (sum of parameters in all experts plus the gating network), the computation required for a single input token only involves the gating network and the K selected experts.
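To make the gap between stored and active parameters concrete, here is a small arithmetic sketch. The sizes (512, 2048, 8 experts, k=2) are hypothetical, and the helper names are made up for this example; it assumes two-layer FFN experts with biases and a bias-free linear router.

```python
def ffn_params(d_model: int, d_hidden: int) -> int:
    """Parameter count of a two-layer FFN with biases: d_model -> d_hidden -> d_model."""
    return (d_model * d_hidden + d_hidden) + (d_hidden * d_model + d_model)


def moe_param_counts(d_model: int, d_hidden: int, num_experts: int, k: int):
    """Return (total parameters, parameters used for one token)."""
    per_expert = ffn_params(d_model, d_hidden)
    gate = d_model * num_experts  # bias-free linear router
    total = num_experts * per_expert + gate
    active = k * per_expert + gate  # only k experts run per token
    return total, active


# Hypothetical sizes, chosen only for illustration
total, active = moe_param_counts(d_model=512, d_hidden=2048, num_experts=8, k=2)
print(f"total={total:,}  active per token={active:,}  ratio={active / total:.2f}")
```

With these numbers, roughly a quarter of the layer's parameters take part in processing any single token.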

The Gating Network

The gating network is responsible for routing decisions. It takes the input token representation (e.g., the output of the self-attention layer in a Transformer) and produces a distribution over the available experts.

A common implementation uses a simple linear layer followed by a Softmax function.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKGate(nn.Module):
    """Gate module to select top k experts."""
    def __init__(self, input_dim, num_experts, k=1):
        super().__init__()
        self.k = k
        # Linear layer to compute logits for experts
        self.gate_linear = nn.Linear(input_dim, num_experts, bias=False)

    def forward(self, x):
        # x shape: [batch_size * seq_len, input_dim]
        # logits shape: [batch_size * seq_len, num_experts]
        logits = self.gate_linear(x)
        
        # Select top-k experts
        # top_k_logits shape: [batch_size * seq_len, k]
        # top_k_indices shape: [batch_size * seq_len, k]
        top_k_logits, top_k_indices = torch.topk(
            logits, self.k, dim=-1
        )
        
        # Apply softmax to top-k logits for weights
        # top_k_weights shape: [batch_size * seq_len, k]
        top_k_weights = F.softmax(top_k_logits, dim=-1)
        
        # Create a sparse weight matrix for combining outputs
        # full_weights shape: [batch_size * seq_len, num_experts]
        full_weights = torch.zeros_like(logits)
        full_weights.scatter_(1, top_k_indices, top_k_weights)
        
        return full_weights, top_k_indices # Return weights and indices

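This TopKGate selects the k experts with the highest logits for each token and normalizes their weights with a softmax over the selected logits only. It returns both the dense weight matrix (with zeros for non-selected experts) and the indices of the chosen experts, which are needed for routing. Note that with k=1 the softmax over a single logit is always 1, so the gating weight is constant and no gradient reaches gate_linear. To train the router in that setting, a common alternative is to apply the softmax over all expert logits first and use the selected expert's probability as its weight.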
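The selection rule is easier to follow on hand-written logits. The sketch below applies the same top-k and softmax steps as TopKGate.forward, but as a standalone function (route_top_k is a name made up for this example) so it can be run without a trained linear layer.

```python
import torch
import torch.nn.functional as F


def route_top_k(logits: torch.Tensor, k: int):
    """Apply top-k selection and softmax over the selected logits only."""
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
    top_k_weights = F.softmax(top_k_logits, dim=-1)
    # Dense [num_tokens, num_experts] matrix, zero for non-selected experts
    full_weights = torch.zeros_like(logits).scatter(1, top_k_indices, top_k_weights)
    return full_weights, top_k_indices


# Two tokens, four experts; values chosen by hand with no ties
logits = torch.tensor([[2.0, 1.0, 0.0, -1.0],
                       [0.0, 0.5, 3.0, 1.0]])

w2, idx2 = route_top_k(logits, k=2)
w1, idx1 = route_top_k(logits, k=1)
print(idx2)  # selected experts per token for k=2
print(w2)    # each row sums to 1 over its two selected experts
print(w1)    # for k=1 every selected weight is exactly 1
```

Token 0 selects experts 0 and 1, token 1 selects experts 2 and 3, and in both cases the expert with the larger logit receives the larger weight.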

The Expert Networks

Experts are typically identical feed-forward networks (FFNs). They process the input tokens routed to them by the gating network. Each expert operates independently.

class Expert(nn.Module):
    """A simple feed-forward expert network."""
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.ReLU() # Or GeLU, SiLU etc.

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x

In an MoE layer, you would instantiate multiple copies of this Expert module, usually held within an nn.ModuleList.

Implementing the MoE Layer in PyTorch

Now, let's combine the gating network and the experts into a complete MoE layer.

class MoELayer(nn.Module):
    """Mixture of Experts layer."""
    def __init__(self, input_dim, output_dim, num_experts, k=1, 
                 expert_hidden_dim=None):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.output_dim = output_dim

        if expert_hidden_dim is None:
            expert_hidden_dim = input_dim * 4 # Common practice

        self.gate = TopKGate(input_dim, num_experts, k)
        self.experts = nn.ModuleList(
            [Expert(input_dim, expert_hidden_dim, output_dim) 
             for _ in range(num_experts)]
        )

    def forward(self, x):
        # Assume x shape: [batch_size, seq_len, input_dim]
        original_shape = x.shape
        # Flatten to [N, input_dim] where N = batch * seq_len.
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(-1, original_shape[-1])

        # Get gating weights and expert indices
        # gate_weights: [N, num_experts], top_k_indices: [N, k]
        gate_weights, top_k_indices = self.gate(x)

        # Initialize final output tensor
        final_output = torch.zeros(x.shape[0], self.output_dim, 
                                   device=x.device, dtype=x.dtype)
        
        # Get indices for batch processing
        # flat_top_k_indices: [N * k]
        flat_top_k_indices = top_k_indices.view(-1)
        
        # Map tokens to their assigned experts
        # Create a flat tensor of inputs for batching across experts
        # flat_x: [N * k, input_dim]
        flat_x = x.repeat_interleave(self.k, dim=0)

        # Dispatch tokens to experts and compute outputs
        expert_outputs = []
        for i in range(self.num_experts):
            # Find indices of tokens assigned to expert i
            # idx: [num_tokens_for_expert_i]
            idx = torch.where(flat_top_k_indices == i)[0]
            
            if idx.numel() > 0:
                # Process tokens assigned to this expert
                expert_input = flat_x[idx]
                expert_output = self.experts[i](expert_input)
                
                # Store output and original indices
                expert_outputs.append((idx, expert_output))

        # Combine expert outputs using gating weights
        # We need to map the results back to the original token positions

        for idx, output in expert_outputs:
            # Find the corresponding weights for these outputs
            # Need original token indices and expert indices
            original_indices = idx // self.k # Get original token index (0 to N-1)
            expert_indices = flat_top_k_indices[idx] # Which expert (0 to num_experts-1)
            
            # Gather the weights using original and expert indices
            weights = gate_weights[original_indices, expert_indices].unsqueeze(1)
            
            # Weight the expert output
            weighted_output = output * weights
            
            # Add to the final output tensor at the correct positions
            # Use index_add_ for scatter-add operation
            final_output.index_add_(0, original_indices, weighted_output)
            
        # Reshape back to original shape [batch_size, seq_len, output_dim]
        final_output = final_output.view(original_shape[0], original_shape[1], 
                                         self.output_dim)
        return final_output, gate_weights # Return output and weights for aux loss

This implementation performs the core MoE logic:

  1. Flattens the input batch and sequence dimensions.
  2. Uses the TopKGate to get weights and the indices of the top k experts for each token.
  3. Efficiently routes tokens to their assigned experts using indexing (torch.where). Note: More advanced implementations might use custom CUDA kernels for optimized routing, especially for large scale models, but this index-based approach illustrates the process.
  4. Computes outputs for batches assigned to each expert.
  5. Combines the expert outputs, weighted by the corresponding gating scores, using index_add_ to place results correctly in the output tensor.
  6. Reshapes the output back to the original [batch_size, seq_len, output_dim] format.
  7. Returns the final output and the gating weights (needed for the auxiliary loss).

A note on the design: looping over individual tokens would be simple but slow, because an expert would be called separately for every token. Looping over experts and processing all tokens assigned to each expert as one batch needs at most num_experts forward calls. Combining the outputs then requires mapping each result back to its original token position and pairing it with the matching gating weight, which is what the index_add_ call does.
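One way to gain confidence in index-based routing is to compare it against a dense reference that runs every expert on every token and then applies the same weights. The sketch below does this on a tiny random problem. It is self-contained and uses single linear layers as hypothetical stand-ins for the experts, not the Expert class above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_tokens, dim, num_experts, k = 6, 4, 3, 2

# Hypothetical stand-ins: single linear layers instead of full FFN experts
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
x = torch.randn(num_tokens, dim)
logits = torch.randn(num_tokens, num_experts)

# Top-k gating weights in dense [num_tokens, num_experts] form
top_vals, top_idx = torch.topk(logits, k, dim=-1)
gate_weights = torch.zeros_like(logits).scatter(1, top_idx, top_vals.softmax(dim=-1))

# Dense reference: run every expert on every token, then weight and sum
all_outputs = torch.stack([expert(x) for expert in experts], dim=1)  # [N, E, dim]
dense_out = (gate_weights.unsqueeze(-1) * all_outputs).sum(dim=1)

# Sparse dispatch: each expert only sees the tokens routed to it
sparse_out = torch.zeros(num_tokens, dim)
for i, expert in enumerate(experts):
    token_idx = torch.where((top_idx == i).any(dim=-1))[0]
    if token_idx.numel() > 0:
        weights = gate_weights[token_idx, i].unsqueeze(1)
        sparse_out.index_add_(0, token_idx, weights * expert(x[token_idx]))

matches = torch.allclose(dense_out, sparse_out, atol=1e-6)
print(matches)
```

The dense version costs num_experts forward passes over all tokens, so it is only useful as a test oracle, not as an implementation.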

Training Considerations

Training MoE models requires addressing a potential issue: load imbalance. If the gating network consistently routes most tokens to a few 'popular' experts, those experts get trained disproportionately, while others remain under-utilized. This undermines the benefit of having many specialized experts.

Load Balancing Loss

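To encourage balanced routing, an auxiliary loss function is typically added to the main task loss (e.g., cross-entropy). Several forms exist. The one used here, popularized by the Switch Transformer, penalizes the sum over experts of the product of two quantities: the fraction of tokens routed to an expert and the average gating weight that expert receives.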

Let $N$ be the total number of tokens in a batch and $E$ be the number of experts. Define $n_i$ as the number of tokens assigned to expert $i$. The fraction of tokens assigned to expert $i$ is $f_i = n_i / N$. The average gating weight assigned to expert $i$ over the batch is $P_i = \frac{1}{N} \sum_{x} \text{gate\_weights}(x)_i$, where the sum is over all tokens $x$ in the batch and $\text{gate\_weights}(x)_i$ is the gating weight for expert $i$ for token $x$.

The auxiliary loss $L_{aux}$ can be defined as:

$$L_{aux} = \alpha \cdot E \cdot \sum_{i=1}^{E} f_i P_i$$

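Here, $\alpha$ is a scaling coefficient (hyperparameter, e.g., 0.01). Minimizing this loss encourages both the fraction of tokens $f_i$ and the average gating weight $P_i$ assigned to each expert to be similar across experts, promoting load balance. With top-1 routing and perfectly uniform assignment, $f_i = P_i = 1/E$ for every expert, so the sum equals $1/E$ and the factor $E$ normalizes the balanced value of the loss to $\alpha$ regardless of the number of experts.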
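A quick numeric check shows how the formula responds to imbalance. The fractions below are made-up values for four experts, and aux_loss is a plain-Python helper written for this example rather than part of the implementation.

```python
def aux_loss(f, p, alpha=0.01):
    """Compute alpha * E * sum_i(f_i * P_i) from per-expert lists."""
    num_experts = len(f)
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Hypothetical routing statistics for four experts
balanced = aux_loss([0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25])
skewed = aux_loss([0.70, 0.10, 0.10, 0.10], [0.70, 0.10, 0.10, 0.10])

print(f"balanced: {balanced:.4f}")  # equals alpha
print(f"skewed:   {skewed:.4f}")    # larger, so imbalance is penalized
```

The balanced case evaluates to 0.01 (the value of alpha), while routing 70% of tokens to one expert roughly doubles the penalty to 0.0208.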

def calculate_load_balancing_loss(gate_weights, num_experts):
    """Calculates the unscaled load balancing loss E * sum_i(f_i * P_i).

    Args:
        gate_weights: Tensor of shape [batch_size * seq_len, num_experts],
            the sparse top-k weights returned by the gate (zero for
            experts that were not selected).
        num_experts: Total number of experts.

    Returns:
        Scalar loss tensor (multiply by alpha before adding to the task loss).
    """
    num_tokens = gate_weights.shape[0]

    # f_i: fraction of tokens routed to each expert, from hard assignments.
    # A nonzero weight marks a selected expert; this count carries no gradient.
    assigned = (gate_weights > 0).to(gate_weights.dtype)
    f_i = assigned.sum(dim=0) / num_tokens  # Shape [num_experts]

    # P_i: mean gating weight per expert (the differentiable term)
    P_i = gate_weights.mean(dim=0)  # Shape [num_experts]

    # The caller scales the result by the hyperparameter alpha
    loss = num_experts * torch.sum(f_i * P_i)
    return loss

The function above returns the unscaled term $E \cdot \sum_{i=1}^{E} f_i P_i$; multiplying it by $\alpha$ gives $L_{aux}$. The total loss during training becomes:

$$L_{total} = L_{task} + L_{aux}$$

Other Considerations

  • Capacity Factor: To manage computational load and memory, sometimes a 'capacity factor' $C$ is introduced. This limits the number of tokens each expert can process per batch to $C \times (\text{tokens per batch} / \text{num experts})$. Tokens exceeding this capacity for their chosen expert might be dropped or handled differently (e.g., routed to their second-choice expert if capacity allows). This adds complexity but helps stabilize training, especially in distributed settings.
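  • Initialization: Careful initialization of the gating network can be important. For example, small initial gate weights keep the initial logits near zero, so routing starts close to uniform instead of favoring a few experts.
  • Noisy Gating: Adding small Gaussian noise to the gating logits before the Top-K selection during training can sometimes improve load balancing and training stability, because experts that narrowly miss the top K still receive some tokens.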
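The capacity rule from the first bullet can be sketched in a few lines of plain Python. Rounding the capacity up is an assumption made here (the text does not specify it), the helper names are hypothetical, and the sketch only counts overflow tokens for top-1 routing rather than rerouting them.

```python
import math
from collections import Counter


def expert_capacity(tokens_per_batch: int, num_experts: int, capacity_factor: float) -> int:
    """Max tokens per expert: C * (tokens per batch / num experts), rounded up (assumption)."""
    return math.ceil(capacity_factor * tokens_per_batch / num_experts)


def count_dropped(assignments, num_experts: int, capacity_factor: float) -> int:
    """Count tokens that exceed their expert's capacity.

    assignments: list with one top-1 expert index per token.
    """
    cap = expert_capacity(len(assignments), num_experts, capacity_factor)
    counts = Counter(assignments)
    return sum(max(0, counts[e] - cap) for e in range(num_experts))


# Hypothetical batch of 8 tokens where expert 0 is over-subscribed
assignments = [0, 0, 0, 0, 0, 1, 2, 3]
capacity = expert_capacity(len(assignments), num_experts=4, capacity_factor=1.25)
dropped = count_dropped(assignments, num_experts=4, capacity_factor=1.25)
print(f"capacity per expert: {capacity}, dropped tokens: {dropped}")
```

Here each expert may take 3 tokens, so 2 of the 5 tokens sent to expert 0 overflow. A larger capacity factor drops fewer tokens at the cost of more padding and memory.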

Practical Example: Integrating MoE into a Transformer

Replacing the FFN layer in a standard Transformer block is straightforward.

class TransformerBlockWithMoE(nn.Module):
    def __init__(self, embed_dim, num_heads, num_experts, k=1, 
                 dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, 
                                             batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.moe_layer = MoELayer(embed_dim, embed_dim, 
                                  num_experts, k)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None): # mask for attention if needed
        # Multi-Head Attention part
        attn_output, _ = self.attention(x, x, x, attn_mask=mask)
        x = x + self.dropout(attn_output) # Residual connection
        x = self.norm1(x)
        
        # MoE Layer part
        moe_output, gate_weights = self.moe_layer(x)
        x = x + self.dropout(moe_output) # Residual connection
        x = self.norm2(x)
        
        # Return x and gate_weights for auxiliary loss calculation
        return x, gate_weights 

During the main training loop, you would collect the gate_weights from each MoE layer, compute calculate_load_balancing_loss for each of them, scale the result by the coefficient alpha, and add it to your primary task loss before backpropagation.
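The loss combination described here might look like the following sketch. To keep it runnable on its own, load_balancing_term is a compact stand-in for calculate_load_balancing_loss, combine_losses is a hypothetical helper, and summing the auxiliary term over layers (rather than averaging) is an assumption.

```python
import torch


def load_balancing_term(gate_weights: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Unscaled E * sum_i(f_i * P_i) from sparse top-k gate weights of shape [N, E]."""
    f_i = (gate_weights > 0).float().mean(dim=0)  # fraction of tokens per expert
    P_i = gate_weights.mean(dim=0)                # mean gating weight per expert
    return num_experts * torch.sum(f_i * P_i)


def combine_losses(task_loss, gate_weights_per_layer, num_experts, alpha=0.01):
    """Add the scaled auxiliary loss of every MoE layer to the task loss."""
    aux = sum(load_balancing_term(w, num_experts) for w in gate_weights_per_layer)
    return task_loss + alpha * aux


# Hand-written top-1 gate weights for two layers, 4 tokens, 2 experts
layer_a = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])  # balanced
layer_b = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # skewed

task_loss = torch.tensor(2.0)  # placeholder for e.g. a cross-entropy value
total = combine_losses(task_loss, [layer_a, layer_b], num_experts=2)
print(total)
```

In a real loop, task_loss would come from the model's predictions and total.backward() would follow. The balanced layer contributes 1.0 and the skewed layer 1.25 before scaling.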

Conclusion

Implementing Mixture of Experts in PyTorch involves creating specialized expert networks and a gating mechanism to route inputs effectively. Essential steps include defining the expert and gating modules, assembling them into an MoELayer, and handling the token routing logic during the forward pass.

Most importantly, training MoE models requires managing expert load balance, typically by adding an auxiliary loss term. This ensures that all experts contribute meaningfully, enabling models to scale to billions or even trillions of parameters while maintaining computational efficiency.

While this guide provides a foundational implementation, production systems often employ more sophisticated routing strategies, custom kernels, and distributed training techniques to handle the scale and complexity of state-of-the-art MoE models. Nonetheless, the principles and PyTorch code presented here offer a solid starting point for incorporating MoE into your own network architectures.

© 2025 ApX Machine Learning. All rights reserved.
