This practical exercise implements several advanced routing mechanisms. It is designed to solidify understanding of how each router operates, its specific implementation details, and the trade-offs involved. Each router will be constructed as a modular PyTorch nn.Module, making them interchangeable within a larger MoE layer.
For this exercise, assume we are working with a batch of token embeddings. We will define a standard set of dimensions to ensure our examples are consistent.
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- Configuration ---
NUM_EXPERTS = 8
D_MODEL = 512
BATCH_SIZE = 4
SEQ_LEN = 1024
TOKENS_PER_BATCH = BATCH_SIZE * SEQ_LEN
Our goal is to create routers that take a tensor of shape (TOKENS_PER_BATCH, D_MODEL) and produce the necessary assignments and auxiliary losses for an MoE layer.
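In a transformer, the router's input typically arrives as a [BATCH_SIZE, SEQ_LEN, D_MODEL] activation tensor, which is flattened into a single token dimension before routing. Here is a quick sketch of that convention (the variable names are illustrative):

# Hidden states as they leave the previous layer
hidden_states = torch.randn(BATCH_SIZE, SEQ_LEN, D_MODEL)
# Flatten batch and sequence dims into one token dimension for the router
tokens = hidden_states.view(TOKENS_PER_BATCH, D_MODEL)
print(tokens.shape)  # torch.Size([4096, 512])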
The Noisy Top-k router is a common and effective strategy for improving load balancing during training. It works by adding random noise to the router's logits before selecting the top-k experts. This stochasticity helps prevent the router from consistently favoring the same few experts.
The noise is typically drawn from a normal distribution and scaled by a learned weight matrix. This noise is only applied during training.
$$ \text{logits} = x W_g, \qquad \text{noisy\_logits} = \text{logits} + \epsilon \odot \mathrm{Softplus}(x W_{\text{noise}}), \qquad \epsilon \sim \mathcal{N}(0, 1) $$

Let's implement this, selecting k = 2 experts for each token.
class NoisyTopKRouter(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # Layer to generate logits
        self.gate = nn.Linear(d_model, num_experts)
        # Layer to generate the noise scaling factor
        self.noise_net = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x shape: [TOKENS_PER_BATCH, d_model]
        logits = self.gate(x)
        # Add noise during training only
        if self.training:
            noise_logits = self.noise_net(x)
            # Softplus ensures the noise scaling factor is positive
            noise = torch.randn_like(logits) * F.softplus(noise_logits)
            logits = logits + noise
        # Get the top-k logits and their indices
        # top_k_logits shape:  [TOKENS_PER_BATCH, top_k]
        # top_k_indices shape: [TOKENS_PER_BATCH, top_k]
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)
        # Softmax over the selected logits gives the combination weights
        gating_scores = F.softmax(top_k_logits, dim=-1)
        # Build a mask for the load balancing loss: a 1 for each
        # token-expert pair that was selected
        # B = TOKENS_PER_BATCH, E = NUM_EXPERTS
        # router_mask shape: [B, E]
        zero_mask = torch.zeros(x.size(0), self.num_experts, device=x.device)
        router_mask = zero_mask.scatter(1, top_k_indices, 1)
        return top_k_indices, gating_scores, router_mask
# --- Example Usage ---
noisy_router = NoisyTopKRouter(D_MODEL, NUM_EXPERTS, top_k=2)
noisy_router.train() # Set to training mode to enable noise
input_tokens = torch.randn(TOKENS_PER_BATCH, D_MODEL)
indices, scores, mask = noisy_router(input_tokens)
print("Noisy Top-k Router Output:")
print("Indices Shape:", indices.shape) # [TOKENS_PER_BATCH, 2]
print("Scores Shape:", scores.shape) # [TOKENS_PER_BATCH, 2]
print("Mask Shape:", mask.shape) # [TOKENS_PER_BATCH, 8]
The output gives us exactly what a subsequent MoE layer needs: which experts to route to (indices), how to weight their outputs (scores), and a mask to help compute the load balancing loss.
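To make the loss concrete, here is one widely used formulation as a sketch: the fraction of tokens dispatched to each expert, multiplied by the mean routing probability that expert receives, summed and scaled. Note that it needs the full softmax over all expert logits, which NoisyTopKRouter computes only over the selected top-k; treating router_probs as a separate input (which the router would need a small change to expose) is an assumption of this sketch.

def load_balancing_loss(router_probs, router_mask):
    """Sketch of a common auxiliary load balancing loss.

    router_probs: [B, E] softmax over ALL expert logits (assumed to be
                  returned by the router; the class above would need a
                  small change to expose it).
    router_mask:  [B, E] with a 1.0 for each selected token-expert pair.
    """
    num_experts = router_probs.size(-1)
    # f: fraction of tokens dispatched to each expert, shape [E]
    fraction_per_expert = router_mask.mean(dim=0)
    # P: mean routing probability assigned to each expert, shape [E]
    prob_per_expert = router_probs.mean(dim=0)
    # Scaled by num_experts; a perfectly uniform top-k router yields
    # a value of top_k (1.0 when k=1)
    return num_experts * torch.sum(fraction_per_expert * prob_per_expert)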
The Switch Transformer architecture simplifies routing by setting k=1. This design choice significantly reduces communication costs and computational complexity, as each token is processed by only a single expert. The load balancing loss becomes even more important in this setup to prevent expert under-utilization.
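For reference, the auxiliary loss proposed in the Switch Transformer paper takes the form

$$ \mathcal{L}_{\text{aux}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \, P_i $$

where $N$ is the number of experts, $f_i$ is the fraction of tokens dispatched to expert $i$, $P_i$ is the mean router probability assigned to expert $i$, and $\alpha$ is a small coefficient (the paper uses $10^{-2}$). This is the same $f \cdot P$ product computed in the loss sketch above; the loss is minimized when both quantities are uniform across experts.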
The implementation is a direct simplification of our NoisyTopKRouter: top_k is fixed at 1, and the noise mechanism is omitted for clarity, though it could be retained here as well.
class SwitchRouter(nn.Module):
    def __init__(self, d_model, num_experts):
        super().__init__()
        self.num_experts = num_experts
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x shape: [TOKENS_PER_BATCH, d_model]
        logits = self.gate(x)  # [B, E]
        # Softmax over all experts; these probabilities also feed the
        # load balancing loss
        router_probs = F.softmax(logits, dim=-1)
        # Select the single best expert
        # top_1_scores is the maximum routing probability
        # top_1_indices is the index of the selected expert
        top_1_scores, top_1_indices = torch.max(router_probs, dim=-1)
        # One-hot mask over the selected experts, used both for routing
        # and for calculating the loss. Cast to float so it matches the
        # mask dtype returned by NoisyTopKRouter.
        # one_hot_mask shape: [B, E]
        one_hot_mask = F.one_hot(top_1_indices, num_classes=self.num_experts).float()
        # The gating score is the router probability of the selected expert,
        # returned in a shape consistent with the top-k router.
        gating_scores = top_1_scores.unsqueeze(-1)
        return top_1_indices.unsqueeze(-1), gating_scores, one_hot_mask
# --- Example Usage ---
switch_router = SwitchRouter(D_MODEL, NUM_EXPERTS)
switch_router.eval() # No difference between train/eval in this simple version
input_tokens = torch.randn(TOKENS_PER_BATCH, D_MODEL)
indices, scores, mask = switch_router(input_tokens)
print("\nSwitch Router (Top-1) Output:")
print("Indices Shape:", indices.shape) # [TOKENS_PER_BATCH, 1]
print("Scores Shape:", scores.shape) # [TOKENS_PER_BATCH, 1]
print("Mask Shape:", mask.shape) # [TOKENS_PER_BATCH, 8]
Notice the output shapes are consistent with our previous router, ensuring modularity. The mask is now a one-hot vector for each token, reflecting the k=1 routing decision.
The fundamental difference between routing strategies lies in how they assign tokens to experts. A "hard" mechanism like the Switch router makes a discrete choice, sending each token to a small, fixed number of experts. A "soft" mechanism instead computes a weighted average over all experts, producing a blended output.
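To make the contrast concrete, here is a minimal soft-routing sketch in which every expert processes every token and the outputs are blended by the full softmax weights. The single-linear experts and the class name are illustrative assumptions; note that soft routing gives up sparsity entirely, since no expert can be skipped.

class SoftRouterMoE(nn.Module):
    def __init__(self, d_model, num_experts):
        super().__init__()
        self.gate = nn.Linear(d_model, num_experts)
        # Single linear layers stand in for full expert networks here
        self.experts = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(num_experts)]
        )

    def forward(self, x):
        # weights shape: [B, E]; every expert receives a nonzero weight
        weights = F.softmax(self.gate(x), dim=-1)
        # expert_outputs shape: [B, E, d_model]; dense compute, no sparsity
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=1
        )
        # Weighted average over the expert dimension
        return (weights.unsqueeze(-1) * expert_outputs).sum(dim=1)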
A primary motivation for using noisy routing is to improve load balance. Without it, a standard router might persistently send most tokens to a few "popular" experts, leaving others under-trained. Noise encourages exploration, spreading the load more evenly.
Measured over a batch, a Noisy Top-2 router typically produces a noticeably more uniform token distribution across the 8 experts, while a standard Top-2 router can suffer severe imbalance, with some experts receiving a disproportionate number of tokens.
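You can measure this distribution directly by summing the router mask over the token dimension, as in the quick sketch below. With a freshly initialized router the two counts will look similar, since imbalance emerges during training, but this is the same measurement you would log to monitor expert load.

balance_router = NoisyTopKRouter(D_MODEL, NUM_EXPERTS, top_k=2)
tokens = torch.randn(TOKENS_PER_BATCH, D_MODEL)

balance_router.train()   # noise enabled
_, _, mask_noisy = balance_router(tokens)
balance_router.eval()    # noise disabled
_, _, mask_plain = balance_router(tokens)

# Column sums of the mask give tokens-per-expert
print("Tokens per expert (with noise):   ", mask_noisy.sum(dim=0).tolist())
print("Tokens per expert (without noise):", mask_plain.sum(dim=0).tolist())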
These modular routers can be easily slotted into a complete MoE layer. The MoE layer's responsibility is to use the router's output to perform the sparse computation. Here is a skeleton of how our NoisyTopKRouter would be used.
class MoELayer(nn.Module):
    def __init__(self, d_model, num_experts, top_k):
        super().__init__()
        self.router = NoisyTopKRouter(d_model, num_experts, top_k)
        # self.experts = nn.ModuleList([...])  # A list of expert networks
        # ...

    def forward(self, x):
        # 1. Get routing assignments
        # indices, scores, mask = self.router(x)

        # 2. Perform the dispatch/gather operation. This is a complex
        #    step that permutes tokens based on 'indices' so that each
        #    expert receives a batch of its assigned tokens.

        # 3. Compute expert outputs in parallel
        # expert_outputs = ...

        # 4. Combine expert outputs using 'scores'
        # final_output = ...

        # 5. Calculate and return the load balancing loss using 'mask'
        # load_balancing_loss = ...
        # return final_output, load_balancing_loss
        pass
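To make the skeleton concrete, here is a minimal, loop-based version of the forward pass. It iterates over experts with index selection instead of a true dispatch/gather permutation, so it is illustrative rather than efficient, and the two-layer Expert module is an assumption of this sketch rather than part of the exercise.

class Expert(nn.Module):
    """A simple feed-forward expert (illustrative)."""
    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)


class SimpleMoELayer(nn.Module):
    """Loop-based MoE forward pass. Real systems replace the Python loop
    with a batched dispatch/gather so experts run in parallel."""
    def __init__(self, d_model, num_experts, top_k=2):
        super().__init__()
        self.router = NoisyTopKRouter(d_model, num_experts, top_k)
        self.experts = nn.ModuleList(
            [Expert(d_model) for _ in range(num_experts)]
        )

    def forward(self, x):
        # indices: [B, k], scores: [B, k], mask: [B, E]
        indices, scores, mask = self.router(x)
        output = torch.zeros_like(x)
        for expert_idx, expert in enumerate(self.experts):
            # token_ids: tokens assigned to this expert
            # k_slot: which of the k routing slots selected it
            token_ids, k_slot = (indices == expert_idx).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            # Run the expert only on its assigned tokens and scale the
            # result by the corresponding gating score
            expert_out = expert(x[token_ids])
            output[token_ids] += scores[token_ids, k_slot].unsqueeze(-1) * expert_out
        # The load balancing loss would be computed from 'mask' (see the
        # load_balancing_loss sketch earlier); we return the mask so the
        # caller can do so.
        return output, mask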
This hands-on exercise demonstrates that while the strategies differ, their implementations can be contained within a clean, modular interface. The choice of router is a critical design decision that you can now analyze not just in theory, but with a clear understanding of the underlying code. The next chapter on distributed training will show how these routing decisions interact with system-level parallelism to enable massive model scale.