You will modify a standard Transformer encoder block, replacing its dense feed-forward network (FFN) with a Mixture of Experts (MoE) layer. This is the most common method for integrating MoE into existing architectures and directly demonstrates the principle of increasing model parameters while managing computational load.
We will proceed in four steps:

1. Examine the dense feed-forward network inside a standard Transformer encoder block.
2. Build a sparse MoE layer with top-2 gating and an auxiliary load-balancing loss.
3. Define a MoETransformerEncoderLayer that substitutes the FFN with our MoE layer.
4. Update the training loop to combine the auxiliary loss with the primary task loss.

For this exercise, we will use PyTorch. A strong familiarity with torch.nn.Module is assumed.
First, let's examine a typical feed-forward network inside a Transformer block. It usually consists of two linear layers with a non-linear activation function in between. This FFN is applied independently to each token representation after the self-attention mechanism.
Here is a simplified implementation of the FFN portion of a TransformerEncoderLayer:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StandardFFN(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expand to d_ff, apply ReLU and dropout, then project back to d_model
        x = self.linear2(self.dropout(F.relu(self.linear1(x))))
        return x
```
This StandardFFN module is our target for replacement. It is a dense operation; every input token is processed through the same set of weights in linear1 and linear2.
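To make the dense behavior concrete, here is a quick shape check (the dimensions are hypothetical, chosen only for illustration); every token in the batch passes through the same two weight matrices:

```python
# Hypothetical dimensions for illustration only
ffn = StandardFFN(d_model=512, d_ff=2048)
tokens = torch.randn(4, 16, 512)   # (batch, seq_len, d_model)
out = ffn(tokens)
print(out.shape)                   # torch.Size([4, 16, 512])
```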
Now, we will build the MoE layer that will serve as the sparse replacement for the StandardFFN. Our implementation will feature top-2 gating, which means each token will be routed to two experts. This design follows the principles discussed in Chapter 2.
The MoE layer requires three main components:

- A set of expert networks, each a small feed-forward block (the Expert class below).
- A gating network that scores every expert for each token and selects the top-k.
- An auxiliary load-balancing loss that discourages the gate from overusing a few experts.
```python
class Expert(nn.Module):
    """A simple feed-forward network expert."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, n_experts: int, k: int = 2):
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.k = k
        # A list of N expert networks
        self.experts = nn.ModuleList([Expert(d_model, d_ff) for _ in range(n_experts)])
        # Gating network: one routing logit per expert
        self.gate = nn.Linear(d_model, n_experts)
        # Auxiliary loss coefficient
        self.aux_loss_coef = 0.01

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, _ = x.shape
        x_flat = x.view(-1, self.d_model)  # (batch_size * seq_len, d_model)
        num_tokens = x_flat.shape[0]

        # 1. Get gating scores and select the top-k experts for each token
        gate_logits = self.gate(x_flat)
        gate_scores = F.softmax(gate_logits, dim=-1)
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.k, dim=-1)
        # Normalize the top-k scores so each token's routing weights sum to 1
        top_k_scores = top_k_scores / top_k_scores.sum(dim=-1, keepdim=True)

        # 2. Calculate the auxiliary load-balancing loss.
        # Each expert's share of dispatched tokens is multiplied by its mean gate
        # probability, encouraging the gating network to spread tokens evenly.
        tokens_per_expert = F.one_hot(top_k_indices, self.n_experts).sum(dim=(0, 1)).float()
        load_balancing_loss = (
            self.aux_loss_coef
            * (self.n_experts / num_tokens)
            * torch.sum(tokens_per_expert * gate_scores.mean(dim=0))
        )

        # 3. Dispatch tokens to experts and combine the weighted results.
        # We iterate over experts for clarity, but this can be optimized.
        final_output = torch.zeros_like(x_flat)
        for i in range(self.n_experts):
            # Which (token, slot) pairs were routed to this expert?
            token_ids, slot_ids = torch.where(top_k_indices == i)
            if token_ids.numel() > 0:
                # Normalized gating weights for these tokens
                gating_values = top_k_scores[token_ids, slot_ids]
                # Process the selected tokens with this expert
                expert_output = self.experts[i](x_flat[token_ids])
                # Weight the expert output by its gating score and accumulate
                final_output[token_ids] += expert_output * gating_values.unsqueeze(-1)

        return final_output.view(batch_size, seq_len, -1), load_balancing_loss
```
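For reference, the load-balancing term computed in step 2 follows the common formulation in which each expert's share of dispatched tokens is weighted by its mean gate probability. Writing $\alpha$ for aux_loss_coef, $N$ for n_experts, $T$ for num_tokens, and $p_i(x_t)$ for the softmax gate probability of expert $i$ on token $x_t$:

$$
\mathcal{L}_{\text{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \, P_i,
\qquad
f_i = \frac{\text{dispatches to expert } i}{T},
\qquad
P_i = \frac{1}{T} \sum_{t=1}^{T} p_i(x_t)
$$

With top-2 routing each token contributes two dispatches, so the $f_i$ sum to 2 rather than 1; the loss is minimized when both the dispatch counts and the gate probabilities are spread evenly across experts.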
Note: The dispatch logic in the code above uses a loop for clarity. In high-performance implementations like those discussed in Chapter 3, this token-to-expert dispatch is a highly optimized operation, often handled by custom CUDA kernels or specialized libraries to avoid explicit loops and minimize data shuffling.
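As one illustration of that direction (a hypothetical sketch, not a production implementation), the per-expert boolean-mask searches can be replaced by a single argsort that groups all (token, slot) dispatches by expert, so each expert reads one contiguous slice. The helper below reuses the tensors computed inside MoELayer.forward and produces the same combined output:

```python
# Hypothetical sketch of a sort-based dispatch; it assumes the same tensors
# produced in MoELayer.forward (x_flat, top_k_indices, normalized top_k_scores).
def sorted_dispatch(x_flat, top_k_indices, top_k_scores, experts, n_experts):
    num_tokens, k = top_k_indices.shape
    # One (token, expert, weight) triple per dispatch, flattened row-major
    flat_expert_ids = top_k_indices.reshape(-1)
    flat_token_ids = torch.arange(num_tokens, device=x_flat.device).repeat_interleave(k)
    flat_weights = top_k_scores.reshape(-1, 1)

    # Sort dispatches by expert id so each expert sees one contiguous slice
    order = torch.argsort(flat_expert_ids)
    counts = torch.bincount(flat_expert_ids, minlength=n_experts).tolist()

    output = torch.zeros_like(x_flat)
    start = 0
    for expert_id, count in enumerate(counts):
        if count == 0:
            continue
        idx = order[start:start + count]
        expert_out = experts[expert_id](x_flat[flat_token_ids[idx]])
        # Scatter-add the weighted expert outputs back to their token rows
        output.index_add_(0, flat_token_ids[idx], expert_out * flat_weights[idx])
        start += count
    return output
```

The loop over experts remains, but the routing decisions are grouped with a single sort instead of one mask scan per expert, which is closer in spirit to the grouped dispatch used by optimized MoE kernels.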
With the MoELayer defined, we can now create a new MoETransformerEncoderLayer. Its structure mirrors the standard layer, but it uses MoELayer instead of StandardFFN. A significant difference is that its forward method must now also return the auxiliary loss generated by the MoE layer.
Diagram comparing a standard FFN block with a sparse MoE block. In the MoE version, the Gating Network selects a subset of experts (in this case, Expert 2 is active) to process the input, while others (Expert 1, Expert N) remain inactive.
Here is the implementation of the MoE-enabled encoder layer:
```python
class MoETransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, d_ff: int, n_experts: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Replace the standard FFN with our MoE layer
        self.moe_layer = MoELayer(d_model=d_model, d_ff=d_ff, n_experts=n_experts, k=2)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None,
                src_key_padding_mask: torch.Tensor = None):
        # Self-attention block with residual connection and layer norm
        attn_output, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                        key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attn_output)
        src = self.norm1(src)
        # MoE block with residual connection and layer norm
        moe_output, aux_loss = self.moe_layer(src)
        src = src + self.dropout2(moe_output)
        src = self.norm2(src)
        return src, aux_loss
```
Notice that the forward method now returns two values: the final processed tensor src and the aux_loss from the MoE layer. This is a critical modification.
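A quick sanity check (again with hypothetical dimensions) makes the new signature visible:

```python
# Hypothetical dimensions for illustration only
layer = MoETransformerEncoderLayer(d_model=512, nhead=8, d_ff=2048, n_experts=8)
src = torch.randn(4, 16, 512)   # (batch, seq_len, d_model)
out, aux_loss = layer(src)
print(out.shape)                # torch.Size([4, 16, 512])
print(aux_loss.shape)           # torch.Size([]), a scalar loss term
```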
A model built with MoETransformerEncoderLayer will produce an auxiliary loss at each MoE layer. These losses must be collected and added to the primary task loss (e.g., cross-entropy) during training. This encourages the model to learn effective routing strategies alongside its main task.
A typical training loop needs to be updated to handle this.
```python
# --- Assume dataloader, optimizer, criterion, and vocab_size are defined ---

# nn.TransformerEncoder cannot stack these layers directly, because each MoE
# layer returns (output, aux_loss) rather than a single tensor. A small wrapper
# runs the layers itself and accumulates the auxiliary losses. The embedding
# and output projection are included only so the loss computation below is
# self-contained.
class MoETransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 512)
        self.layers = nn.ModuleList([
            MoETransformerEncoderLayer(d_model=512, nhead=8, d_ff=2048, n_experts=8)
            for _ in range(6)
        ])
        self.lm_head = nn.Linear(512, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        total_aux_loss = 0.0
        for layer in self.layers:
            x, aux_loss = layer(x)
            total_aux_loss = total_aux_loss + aux_loss
        return self.lm_head(x), total_aux_loss

model = MoETransformer()

for data, targets in dataloader:
    optimizer.zero_grad()

    # The forward pass returns the output and the summed auxiliary loss
    output, total_aux_loss = model(data)

    # 1. Calculate the primary task loss
    main_loss = criterion(output.view(-1, vocab_size), targets.view(-1))

    # 2. Combine it with the auxiliary loss from the MoE layers
    #    (the coefficient is already applied inside MoELayer)
    total_loss = main_loss + total_aux_loss

    # 3. Backpropagate the combined loss and update the weights
    total_loss.backward()
    optimizer.step()
```
By adding the aux_loss to the main_loss, we create a gradient signal that simultaneously optimizes the model for its primary objective and for balanced routing across its experts. Without this auxiliary loss, the gating network would likely converge to a state where it always picks the same few experts, leading to the "expert collapse" problem discussed in Chapter 1.
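One practical way to watch for this failure mode is to log how tokens are distributed across experts during training. The helper below is a hypothetical diagnostic, not part of the layer itself; it simply re-runs the gating computation without gradients:

```python
# Hypothetical diagnostic helper: fraction of dispatches each expert receives
@torch.no_grad()
def expert_utilization(moe_layer: MoELayer, x: torch.Tensor) -> torch.Tensor:
    x_flat = x.view(-1, moe_layer.d_model)
    gate_scores = F.softmax(moe_layer.gate(x_flat), dim=-1)
    _, top_k_indices = torch.topk(gate_scores, moe_layer.k, dim=-1)
    counts = torch.bincount(top_k_indices.reshape(-1), minlength=moe_layer.n_experts)
    return counts.float() / counts.sum()
```

In a healthy run each fraction stays near 1 / n_experts; a collapsed router concentrates nearly all of the mass on one or two experts.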
This hands-on modification completes the process from a dense FFN to a functional, sparse MoE block within a Transformer. You have now implemented the core architectural change that enables MoE models to scale their parameter counts far beyond their dense counterparts, a foundation upon which the large-scale training and inference techniques in the following chapters are built.