As discussed in the introduction to this chapter, scaling large language models often involves trade-offs between model capacity and computational cost. Standard Transformer models are dense: every parameter participates in the computation for every input token during a forward pass. While increasing the parameter count generally improves performance (as seen in scaling laws), the computational cost, measured in floating-point operations (FLOPs), scales proportionally. This presents a significant barrier when aiming for models with trillions of parameters.
Mixture-of-Experts (MoE) offers an alternative scaling paradigm based on conditional computation. Instead of the single monolithic feed-forward network (FFN) layer found in each Transformer block, an MoE layer contains multiple parallel FFN "expert" networks. For each input token, a routing mechanism, often called a "gating network," dynamically selects a small subset of these experts (e.g., the top 1 or top 2) to process the token.
A simplified view of an MoE layer. An input token is directed by the router to a small subset of available expert networks. Only the selected experts perform computations for that token.
The main idea is sparsity. Although the total number of parameters in the MoE layer (sum of parameters in the router and all experts) can be vastly larger than in a standard FFN layer, the number of parameters used to process a single token remains relatively small, determined by the number of activated experts.
Let's consider a typical Transformer block's FFN layer. If the hidden dimension is $d_{\text{model}}$ and the FFN inner dimension is $d_{\text{ff}}$, the computation per token involves approximately $2 \times d_{\text{model}} \times d_{\text{ff}}$ FLOPs.
In an MoE layer with $N$ experts, each having the same dimensions as the original FFN, and a router selecting the top $k$ experts (where $k$ is much smaller than $N$, often $k=1$ or $k=2$), the computational cost per token is roughly:

$$\text{FLOPs}_{\text{MoE}} \approx \text{FLOPs}_{\text{Router}} + k \times (2 \times d_{\text{model}} \times d_{\text{ff}})$$

The router's computation is usually negligible compared to the experts'. If $k=1$, the FLOPs per token are similar to those of the original dense FFN layer, even though the total parameter count has increased by a factor of roughly $N$. This allows for building models with potentially trillions of parameters while maintaining a manageable computational budget during training and inference.
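To make the trade-off concrete, here is a small back-of-the-envelope calculation comparing a dense FFN layer with an MoE layer. The dimensions, expert count, and top-k value below are arbitrary illustrative choices, not taken from any particular model:

# Back-of-the-envelope comparison (arbitrary illustrative values)
d_model = 4096       # hidden dimension
d_ff = 16384         # FFN inner dimension
num_experts = 64     # N
top_k = 2            # k

# Dense FFN layer: parameters and approximate FLOPs per token
dense_params = 2 * d_model * d_ff
dense_flops_per_token = 2 * d_model * d_ff

# MoE layer: total parameters grow roughly N-fold, but only the router
# and k experts run for each token (bias terms ignored)
moe_total_params = num_experts * dense_params + d_model * num_experts
moe_flops_per_token = 2 * d_model * num_experts + top_k * dense_flops_per_token

print(f"Dense params:      {dense_params:,}")          # ~134 million
print(f"MoE total params:  {moe_total_params:,}")      # ~8.6 billion
print(f"Dense FLOPs/token: {dense_flops_per_token:,}")
print(f"MoE FLOPs/token:   {moe_flops_per_token:,}")   # only ~2x the dense cost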
Here is a simplified implementation of this structure in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Expert(nn.Module):
    """A simple feed-forward network expert."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()  # Or GELU, SwiGLU, etc.

    def forward(self, x):
        return self.fc2(self.activation(self.fc1(x)))


class MoELayer(nn.Module):
    """Mixture-of-Experts layer."""

    def __init__(self, d_model, d_ff, num_experts, top_k):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        # Pool of experts
        self.experts = nn.ModuleList(
            [Expert(d_model, d_ff) for _ in range(num_experts)]
        )
        # Gating network (learns to route tokens to experts)
        # Input: token representation (d_model)
        # Output: one score per expert (num_experts)
        self.gating_network = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # Assume x has shape (batch_size, seq_len, d_model);
        # some implementations use (seq_len, batch_size, d_model) instead
        batch_size, seq_len, d_model = x.shape

        # Flatten to (batch_size * seq_len, d_model)
        x_reshaped = x.view(-1, d_model)

        # 1. Get routing weights from the gating network
        # router_logits shape: (batch_size * seq_len, num_experts)
        router_logits = self.gating_network(x_reshaped)
        routing_weights = F.softmax(router_logits, dim=1)

        # 2. Select the top-k experts and their weights
        # top_k_weights, top_k_indices shape: (batch_size * seq_len, top_k)
        top_k_weights, top_k_indices = torch.topk(
            routing_weights, self.top_k, dim=1
        )
        # Renormalize the top-k weights so they sum to 1 per token
        top_k_weights_norm = (
            top_k_weights / top_k_weights.sum(dim=1, keepdim=True)
        )

        # 3. Compute expert outputs (simplified; efficient implementations
        # use batched scatter/gather operations instead of explicit loops)
        final_output = torch.zeros_like(x_reshaped)
        for i in range(batch_size * seq_len):
            token_input = x_reshaped[i]
            for k in range(self.top_k):
                expert_idx = top_k_indices[i, k].item()
                expert_weight = top_k_weights_norm[i, k]
                # Accumulate the weighted output of the selected expert
                expert_output = self.experts[expert_idx](token_input)
                final_output[i] += expert_weight * expert_output

        # Reshape back to (batch_size, seq_len, d_model)
        return final_output.view(batch_size, seq_len, d_model)
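As a quick sanity check, the layer can be instantiated and run on a random batch (the sizes below are arbitrary example values):

# Quick usage check with arbitrary example sizes
moe = MoELayer(d_model=64, d_ff=256, num_experts=8, top_k=2)
dummy_input = torch.randn(4, 10, 64)   # (batch_size, seq_len, d_model)
output = moe(dummy_input)
print(output.shape)                    # torch.Size([4, 10, 64])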
Note: The forward method shown above uses a simple loop for clarity. Practical implementations rely on optimized operations to handle the sparse routing efficiently across batches and devices, avoiding explicit loops over tokens.
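To give a sense of what a more efficient dispatch looks like, the sketch below replaces the per-token loop with a loop over experts, processing all tokens routed to a given expert in one batched call. It assumes the MoELayer defined above and is only an illustration of the pattern, not a production implementation (which would also handle expert capacity limits, load balancing, and cross-device communication):

def moe_forward_vectorized(layer, x):
    """Batched dispatch sketch: loop over experts instead of tokens."""
    batch_size, seq_len, d_model = x.shape
    x_flat = x.view(-1, d_model)

    # Same routing as in MoELayer.forward
    routing_weights = F.softmax(layer.gating_network(x_flat), dim=1)
    top_k_weights, top_k_indices = torch.topk(
        routing_weights, layer.top_k, dim=1
    )
    top_k_weights = top_k_weights / top_k_weights.sum(dim=1, keepdim=True)

    final_output = torch.zeros_like(x_flat)
    for expert_idx, expert in enumerate(layer.experts):
        # Which (token, slot) pairs were routed to this expert?
        token_ids, slot_ids = torch.where(top_k_indices == expert_idx)
        if token_ids.numel() == 0:
            continue
        # One batched forward pass for all tokens assigned to this expert
        expert_out = expert(x_flat[token_ids])
        weights = top_k_weights[token_ids, slot_ids].unsqueeze(1)
        final_output.index_add_(0, token_ids, weights * expert_out)

    return final_output.view(batch_size, seq_len, d_model)

Up to floating-point differences, this produces the same result as the loop-based forward above, but each expert is invoked once per forward pass on a contiguous batch of its assigned tokens.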
The primary benefit of the MoE approach is this decoupling of parameter count from per-token computation. It allows models to increase their capacity substantially, which can lead to better performance and richer knowledge representation, without a proportional increase in the FLOPs required to process each token. Furthermore, experts may come to specialize in processing different types of inputs or linguistic phenomena, although verifying and controlling such specialization remains an area of active research.
However, MoE introduces its own set of challenges, particularly in training dynamics and implementation complexity. Ensuring that the router effectively distributes tokens across experts (load balancing) and managing the communication overhead in distributed settings are important considerations that we will discuss in the following sections.