Translating the architecture of a Mixture of Experts layer into a practical PyTorch implementation involves constructing a self-contained nn.Module that encapsulates the gating and expert logic. For this initial implementation, we will focus on a simple yet effective top-1 routing strategy, where each token is routed to a single, most qualified expert. This provides a clear foundation before we address more complex top-k and switch-style mechanisms in later chapters.
We will build the MoE layer from three primary components: the expert networks themselves, the gating (router) network that scores the experts for each token, and the dispatch logic that sends each token to its selected expert and combines the results.
In most Transformer-based MoE models, the "experts" are simply feed-forward networks (FFNs). Each expert is structurally identical but will learn a different function during training due to the specialized data it receives. Let's define a simple Expert module. It contains two linear layers with a GELU activation function in between, a common pattern in modern Transformers.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """
    A simple feed-forward network expert.

    It processes an input tensor of shape (..., d_model) and returns an output
    of the same shape.
    """
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.GELU(),
            nn.Linear(d_hidden, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
This Expert class is a standard building block. The interesting part of an MoE layer is not the experts themselves, but how the model routes information to them.
Now we will assemble the main MoELayer class. This class will contain the gating network and a list of expert modules. The implementation of the forward method is where the core logic of token dispatching resides.
The process for a forward pass is as follows:
1. The input tensor has shape (batch_size, sequence_length, d_model). It is flattened to (batch_size * sequence_length, d_model) to treat each token independently.
2. The gating network computes a routing logit for every expert, for each token.
3. The softmax function is applied to the logits to obtain the gating weights, g(x).
4. For a top-1 router, the expert with the highest score is identified for each token (a small standalone sketch of this selection step follows the figure caption).
5. Each token is processed by its chosen expert. For top-1 routing, the output is simply the output of the chosen expert, weighted by its gating score.

Data flow for a top-1 MoE layer. Tokens are passed to a gating network, which selects a single expert for processing. The results are then gathered to form the final output.
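To make the softmax and top-1 selection steps concrete, here is a small standalone sketch using made-up router logits for four tokens and three experts. The numbers and variable names are purely illustrative; the MoELayer implementation below performs the same computation before dispatching tokens.

import torch
import torch.nn.functional as F

# Hypothetical router logits for 4 tokens and 3 experts (values are made up).
router_logits = torch.tensor([[ 1.2, -0.3,  0.1],
                              [-0.5,  2.0,  0.4],
                              [ 0.7,  0.6, -1.0],
                              [ 0.0,  0.1,  3.0]])

gating_weights = F.softmax(router_logits, dim=1)              # g(x); each row sums to 1
top1_weight, top1_index = torch.topk(gating_weights, k=1, dim=1)

print(top1_index.squeeze(1))   # which expert each token is routed to
print(top1_weight.squeeze(1))  # the gating score used to scale that expert's output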
Here is the implementation of the MoELayer class. Pay close attention to the comments in the forward method, which explain the dispatch mechanism.
class MoELayer(nn.Module):
    """
    A Mixture of Experts layer.

    Args:
        d_model (int): The dimension of the input and output.
        num_experts (int): The total number of experts.
        d_hidden (int): The hidden dimension of each expert's FFN.
        top_k (int): The number of experts to route each token to.
            Currently, only top_k=1 is supported.
    """
    def __init__(self, d_model: int, num_experts: int, d_hidden: int, top_k: int = 1):
        super().__init__()
        if top_k != 1:
            raise ValueError("This basic implementation only supports top_k=1")

        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k

        # Gating network
        self.gate = nn.Linear(d_model, num_experts)

        # Expert networks
        self.experts = nn.ModuleList(
            [Expert(d_model, d_hidden) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the MoE layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, d_model)

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, d_model = x.shape

        # Reshape for token-level processing
        x = x.view(-1, d_model)  # (batch_size * seq_len, d_model)

        # 1. Get gating logits and weights
        router_logits = self.gate(x)
        gating_weights = F.softmax(router_logits, dim=1)

        # 2. Select top-1 expert for each token
        # topk returns a tuple of (values, indices)
        top_k_weights, top_k_indices = torch.topk(gating_weights, self.top_k, dim=1)

        # For top-1, top_k_indices is (num_tokens, 1). We squeeze it.
        expert_indices = top_k_indices.squeeze(1)

        # 3. Create a final output tensor, initialized to zeros
        final_output = torch.zeros_like(x)

        # 4. Dispatch tokens to their selected experts
        # This is a simple, non-performant way to do it.
        # In practice, this dispatch and combine step is heavily optimized.
        for i in range(self.num_experts):
            # Find all tokens routed to this expert
            token_mask = (expert_indices == i)

            # If no tokens are routed to this expert, skip it
            if token_mask.sum() == 0:
                continue

            # Get the tokens for the current expert
            selected_tokens = x[token_mask]

            # Process tokens through the expert
            expert_output = self.experts[i](selected_tokens)

            # Get the corresponding gating weights for scaling the output.
            # For top-1, top_k_weights has shape (num_tokens, 1), so we select
            # the weights for the tokens routed to this expert.
            gating_scores = top_k_weights[token_mask]

            # Place the scaled expert output back into the final output tensor
            final_output[token_mask] = expert_output * gating_scores

        # Reshape back to original dimensions
        return final_output.view(batch_size, seq_len, d_model)
Let's test our implementation with some dummy data to ensure it works as expected. We will instantiate the MoELayer and pass a random tensor through it.
# Configuration
batch_size = 4
seq_len = 16
d_model = 128
num_experts = 8
d_hidden = 512 # Hidden dimension of each expert FFN
# Create a random input tensor
input_tensor = torch.randn(batch_size, seq_len, d_model)
# Instantiate the MoE layer
moe_layer = MoELayer(d_model=d_model, num_experts=num_experts, d_hidden=d_hidden, top_k=1)
# Perform a forward pass
output_tensor = moe_layer(input_tensor)
# Print shapes to verify
print("Input shape:", input_tensor.shape)
print("Output shape:", output_tensor.shape)
# Check that the output shape is correct
assert input_tensor.shape == output_tensor.shape
Running this code will produce the following output, confirming that the layer correctly processes the input and returns a tensor of the same dimensions.
Input shape: torch.Size([4, 16, 128])
Output shape: torch.Size([4, 16, 128])
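Because the gating network is freshly initialized, routing at this point is essentially arbitrary. As an optional diagnostic (not part of the listing above; the variable names are illustrative), you can reuse the layer's gate to count how many tokens land on each expert. This foreshadows the load-balancing concerns discussed next.

# Count how many tokens the router would send to each expert (diagnostic only).
# This mirrors the top-1 selection inside forward() without running the experts.
with torch.no_grad():
    flat_tokens = input_tensor.view(-1, d_model)        # (batch * seq, d_model)
    logits = moe_layer.gate(flat_tokens)                # (batch * seq, num_experts)
    chosen_expert = logits.argmax(dim=1)                # top-1 expert per token
    tokens_per_expert = torch.bincount(chosen_expert, minlength=num_experts)

print("Tokens per expert:", tokens_per_expert.tolist())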
This implementation provides a functional top-1 MoE layer and demonstrates the fundamental mechanics of sparse routing. However, it is important to recognize its limitations, which highlight the complexities of building these models for production environments:
- Inefficient dispatch: The Python for loop that iterates through each expert is simple to understand but highly inefficient. It does not parallelize the expert computations well and introduces significant overhead. High-performance MoE implementations use optimized kernels to perform this dispatch-and-gather operation efficiently on GPUs.
- No load-balancing loss: Sparse MoE models are typically trained with an auxiliary load-balancing loss that discourages the router from sending most tokens to a few experts. Our forward method would need to return the router_logits and gating_weights so a higher-level training loop could compute this loss and add it to the main task loss. This is an important component for stable training (a minimal sketch of one common formulation appears at the end of this section).
- Hardcoded top-1: Our code is hardcoded for top_k=1. Extending it to top_k > 1 would require a more sophisticated method for combining the outputs of multiple experts for a single token.

This hands-on exercise serves as a solid starting point. You now have a working model of an MoE layer that you can build upon. In the following chapters, we will address these limitations by exploring advanced routing strategies, optimization techniques for training, and methods for efficient, large-scale deployment.
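As a concrete preview of the auxiliary loss mentioned in the second limitation, here is a minimal sketch of one common formulation, the load-balancing loss from the Switch Transformer. It assumes the forward method is modified to also return router_logits and expert_indices; the function name is illustrative rather than part of the listing above.

def load_balancing_loss(router_logits: torch.Tensor,
                        expert_indices: torch.Tensor,
                        num_experts: int) -> torch.Tensor:
    """Encourages a uniform token-to-expert assignment (Switch Transformer style)."""
    probs = F.softmax(router_logits, dim=1)                    # (num_tokens, num_experts)
    # Fraction of tokens dispatched to each expert
    tokens_per_expert = F.one_hot(expert_indices, num_experts).float().mean(dim=0)
    # Mean router probability assigned to each expert
    prob_per_expert = probs.mean(dim=0)
    # Pushes both distributions toward uniform routing across experts
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)

In a training loop, this term is typically scaled by a small coefficient (a value such as 0.01 is common) and added to the main task loss.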