Having explored the theoretical underpinnings of various gating network designs, including top-k routing, noise injection, and architectural variations, we now turn to practical implementation. This section provides hands-on examples using PyTorch to construct custom gating mechanisms. Understanding how to translate these concepts into code is essential for building and experimenting with advanced Mixture of Experts models.
We will implement three common types of gating networks:
1. Standard top-k gating, using a single linear router.
2. Noisy top-k gating, which injects Gaussian noise into the logits during training.
3. Non-linear gating, using a small MLP router.
These examples assume familiarity with PyTorch fundamentals. We will focus specifically on the gating module itself, showing how it processes input tokens and produces routing decisions (expert indices and weights).
First, let's import the necessary libraries and define some configuration parameters we'll use throughout the examples.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# --- Configuration ---
model_dim = 512 # Dimension of input token representation
num_experts = 8 # Total number of experts
top_k = 2 # Number of experts to route each token to
batch_size = 4 # Example batch size
seq_len = 10 # Example sequence length
# Example input tensor (Batch Size, Sequence Length, Model Dimension)
input_tokens = torch.randn(batch_size, seq_len, model_dim)
The standard top-k gate is the most common gating mechanism. A single linear layer projects each input token to one score per expert, torch.topk selects the k highest-scoring experts, and a softmax over the selected scores produces the routing weights.
class StandardTopKGating(nn.Module):
"""
Standard Top-k Gating Network.
Uses a linear layer and softmax to compute expert scores,
then selects the top-k experts based on these scores.
"""
def __init__(self, model_dim: int, num_experts: int, top_k: int):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts
self.top_k = top_k
# Linear layer to map token embedding to expert scores
self.gate_proj = nn.Linear(self.model_dim, self.num_experts, bias=False)
print(f"Initialized StandardTopKGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}")
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass for the gating network.
Args:
x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension)
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- combined_weights (torch.Tensor): Routing weights for selected experts,
shape (Batch Size * Sequence Length, top_k).
- expert_indices (torch.Tensor): Indices of selected experts,
shape (Batch Size * Sequence Length, top_k).
- raw_logits (torch.Tensor): Raw logits output by the linear layer,
shape (Batch Size * Sequence Length, num_experts). Useful for auxiliary losses.
"""
        # Flatten batch and sequence dimensions for the linear layer: (B * S, D)
        x = x.view(-1, self.model_dim)
# Project input tokens to expert scores (logits)
# Shape: (B * S, num_experts)
raw_logits = self.gate_proj(x)
# Get top-k scores and indices using torch.topk
# top_k_logits shape: (B * S, top_k)
# top_k_indices shape: (B * S, top_k)
top_k_logits, top_k_indices = torch.topk(raw_logits, self.top_k, dim=-1)
# Apply softmax to the selected top-k logits to get weights
# Shape: (B * S, top_k)
combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float)
# Return the weights, indices, and the raw logits (needed for potential auxiliary losses)
return combined_weights, top_k_indices, raw_logits
# --- Instantiate and Test ---
standard_gating = StandardTopKGating(model_dim, num_experts, top_k)
weights, indices, logits = standard_gating(input_tokens)
print("\n--- Standard Top-k Gating Output ---")
print("Input Shape:", input_tokens.shape)
print("Combined Weights Shape:", weights.shape)
print("Expert Indices Shape:", indices.shape)
print("Raw Logits Shape:", logits.shape)
# Example output for one token
print("Example Weights (Token 0):", weights[0])
print("Example Indices (Token 0):", indices[0])
The output combined_weights represents the normalized importance score for each chosen expert for each token. The expert_indices tell us which experts were chosen. The raw_logits are often used in calculating auxiliary load balancing losses, which we will discuss in the next chapter.
Flow diagram illustrating the Standard Top-k Gating mechanism. Input tokens are projected, top-k logits and indices are selected, and softmax is applied to the selected logits to produce routing weights.
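Since the routing weights come from a softmax over each token's selected logits, they should sum to one per token. A quick check against the outputs above:

# Each token's top-k routing weights are softmax-normalized, so they sum to 1
print("Weights sum to 1 per token:", torch.allclose(weights.sum(dim=-1), torch.ones(batch_size * seq_len)))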
Adding noise to the gating logits before the top-k selection is a technique used to encourage exploration during training and sometimes improve load balancing or robustness. Gaussian noise is commonly used.
class NoisyTopKGating(nn.Module):
"""
Noisy Top-k Gating Network.
Adds Gaussian noise to the logits before top-k selection during training.
"""
    def __init__(self, model_dim: int, num_experts: int, top_k: int, noise_stddev: float = 1.0):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts
self.top_k = top_k
self.noise_stddev = noise_stddev
self.gate_proj = nn.Linear(self.model_dim, self.num_experts, bias=False)
# Layer for adding noise, only applied during training
self.noise_layer = nn.Linear(self.model_dim, self.num_experts, bias=False)
print(f"Initialized NoisyTopKGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}, NoiseStd={noise_stddev}")
def forward(self, x: torch.Tensor, is_training: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass for the noisy gating network.
Args:
x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension)
is_training (bool): Flag indicating if the model is in training mode. Noise is only added during training.
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- combined_weights (torch.Tensor): Routing weights for selected experts.
- expert_indices (torch.Tensor): Indices of selected experts.
- raw_logits (torch.Tensor): Raw logits *before* noise addition.
"""
        # Flatten batch and sequence dimensions: (B * S, D)
        x = x.view(-1, self.model_dim)
# Get base logits
# Shape: (B * S, num_experts)
clean_logits = self.gate_proj(x)
if is_training:
# Calculate noise contribution
# We use a separate linear layer for noise magnitude, scaled by standard normal noise
# Shape: (B * S, num_experts)
noise_magnitude = self.noise_layer(x)
# Softplus ensures the magnitude scaling is positive
noise_scale = F.softplus(noise_magnitude)
# Sample standard Gaussian noise
# Shape: (B * S, num_experts)
sampled_noise = torch.randn_like(clean_logits) * self.noise_stddev
# Add scaled noise to the clean logits
noisy_logits = clean_logits + (noise_scale * sampled_noise)
else:
# No noise during inference
noisy_logits = clean_logits
# Select top-k based on (potentially noisy) logits
# top_k_logits shape: (B * S, top_k) - these are from the *noisy* logits if training
# top_k_indices shape: (B * S, top_k) - these are from the *noisy* logits if training
top_k_logits, top_k_indices = torch.topk(noisy_logits, self.top_k, dim=-1)
# Apply softmax to the selected top-k logits to get weights
# Shape: (B * S, top_k)
combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float)
# Return weights, indices, and the original *clean* logits for auxiliary loss calculation
return combined_weights, top_k_indices, clean_logits # Note: Return clean_logits
# --- Instantiate and Test ---
noisy_gating = NoisyTopKGating(model_dim, num_experts, top_k)
# Test in training mode
weights_train, indices_train, logits_train = noisy_gating(input_tokens, is_training=True)
print("\n--- Noisy Top-k Gating Output (Training) ---")
print("Input Shape:", input_tokens.shape)
print("Weights Shape (Train):", weights_train.shape)
print("Indices Shape (Train):", indices_train.shape)
print("Logits Shape (Train - Clean):", logits_train.shape)
print("Example Indices (Train - Token 0):", indices_train[0]) # May differ from standard due to noise
# Test in inference mode
weights_eval, indices_eval, logits_eval = noisy_gating(input_tokens, is_training=False)
print("\n--- Noisy Top-k Gating Output (Evaluation) ---")
print("Weights Shape (Eval):", weights_eval.shape)
print("Indices Shape (Eval):", indices_eval.shape) # Should match standard gating if weights are same
print("Example Indices (Eval - Token 0):", indices_eval[0])
# Check whether eval indices match the standard gating module. With independently
# initialized gate weights the selections will generally differ; a meaningful
# comparison requires copying gate_proj weights between the two modules.
# print("Eval indices match standard?", torch.equal(indices_eval, indices))
Notice that noise is only added during training (is_training=True). During evaluation or inference, the behavior reverts to standard top-k selection based on the clean logits. It is also important to return the clean logits for potential use in auxiliary loss calculations, as these reflect the router's underlying preference without the stochastic training noise. The specific noise implementation (a separate learnable noise_layer whose output is passed through softplus to give a positive scale) follows the noisy top-k gating introduced by Shazeer et al. in the sparsely-gated Mixture-of-Experts paper.
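As a side note, rather than threading an explicit is_training flag through every call, an nn.Module can consult its built-in self.training attribute, which PyTorch toggles via .train() and .eval(). A minimal sketch of how the forward logic above could be adapted (same computation, different mode handling):

# Sketch: use nn.Module's built-in training flag instead of an explicit argument.
# Assumes the same attributes as NoisyTopKGating above.
def forward(self, x: torch.Tensor):
    x = x.view(-1, self.model_dim)
    clean_logits = self.gate_proj(x)
    if self.training:  # True after .train(), False after .eval()
        noise_scale = F.softplus(self.noise_layer(x))
        noisy_logits = clean_logits + noise_scale * torch.randn_like(clean_logits) * self.noise_stddev
    else:
        noisy_logits = clean_logits
    top_k_logits, top_k_indices = torch.topk(noisy_logits, self.top_k, dim=-1)
    combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float)
    return combined_weights, top_k_indices, clean_logits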
While linear routers are common, sometimes a more expressive router using non-linearities can capture more complex routing patterns. Here's a simple example using a two-layer MLP with a ReLU activation.
class NonLinearGating(nn.Module):
"""
Non-Linear Top-k Gating Network using a simple MLP.
"""
    def __init__(self, model_dim: int, num_experts: int, top_k: int, hidden_dim_multiplier: int = 2):
super().__init__()
self.model_dim = model_dim
self.num_experts = num_experts
self.top_k = top_k
self.hidden_dim = model_dim * hidden_dim_multiplier
# Simple MLP: Linear -> ReLU -> Linear
self.mlp = nn.Sequential(
nn.Linear(self.model_dim, self.hidden_dim),
nn.ReLU(),
nn.Linear(self.hidden_dim, self.num_experts, bias=False)
)
print(f"Initialized NonLinearGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}, Hidden={self.hidden_dim}")
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass for the non-linear gating network.
Args:
x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension)
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- combined_weights (torch.Tensor): Routing weights for selected experts.
- expert_indices (torch.Tensor): Indices of selected experts.
- raw_logits (torch.Tensor): Raw logits output by the MLP.
"""
        # Flatten batch and sequence dimensions: (B * S, D)
        x = x.view(-1, self.model_dim)
# Get logits from the MLP
# Shape: (B * S, num_experts)
raw_logits = self.mlp(x)
# Get top-k scores and indices
# top_k_logits shape: (B * S, top_k)
# top_k_indices shape: (B * S, top_k)
top_k_logits, top_k_indices = torch.topk(raw_logits, self.top_k, dim=-1)
# Apply softmax to the selected top-k logits
# Shape: (B * S, top_k)
combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float)
# Return weights, indices, and raw logits
return combined_weights, top_k_indices, raw_logits
# --- Instantiate and Test ---
nonlinear_gating = NonLinearGating(model_dim, num_experts, top_k)
weights_nl, indices_nl, logits_nl = nonlinear_gating(input_tokens)
print("\n--- Non-Linear (MLP) Gating Output ---")
print("Input Shape:", input_tokens.shape)
print("Weights Shape:", weights_nl.shape)
print("Indices Shape:", indices_nl.shape)
print("Logits Shape:", logits_nl.shape)
print("Example Indices (Token 0):", indices_nl[0])
This MLP router introduces more parameters and computation compared to the simple linear router. The choice between linear and non-linear routers depends on the specific task and dataset, often involving empirical validation. More complex router architectures, such as those incorporating attention mechanisms, could also be implemented following similar principles.
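To make the extra cost concrete, we can compare the parameter counts of the two routers built above (an illustrative check with this section's configuration, not a benchmark):

# Router parameter counts with model_dim=512, num_experts=8, hidden_dim=1024
linear_params = sum(p.numel() for p in standard_gating.parameters())
mlp_params = sum(p.numel() for p in nonlinear_gating.parameters())
print("Linear router parameters:", linear_params)  # 512 * 8 = 4096
print("MLP router parameters:", mlp_params)        # 512*1024 + 1024 + 1024*8 = 533504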
The primary outputs of any gating mechanism are the combined_weights and expert_indices. These are used within the full MoE layer to combine the outputs of the selected experts. While the full MoE layer implementation is beyond this specific section, the core idea is:
1. Use expert_indices to identify which experts need to process which tokens. This often involves complex dispatching logic in distributed settings (covered in Chapter 4).
2. Use combined_weights to perform a weighted sum of the outputs from the selected experts for each token.
For example (conceptually):
# Conceptual usage (simplified, assumes expert outputs are gathered)
# Assume 'expert_outputs' is a tensor where expert_outputs[i] is the output
# for the i-th token from *one* of its assigned experts.
# The full implementation requires handling multiple experts per token and gathering results.
# Simplified combination for a single token 't' assigned to experts 'e1' and 'e2'
# with weights 'w1' and 'w2':
# final_output_t = w1 * expert_output_t_e1 + w2 * expert_output_t_e2
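To make the combination step concrete, here is a minimal, non-optimized sketch using the outputs of the standard gate above. It assumes a hypothetical list of expert modules (experts) and loops over experts for clarity; production MoE layers replace this loop with batched dispatch and gather operations (Chapter 4):

# Illustrative dispatch-and-combine loop (assumes a hypothetical `experts` list)
experts = nn.ModuleList([nn.Linear(model_dim, model_dim) for _ in range(num_experts)])
flat_tokens = input_tokens.view(-1, model_dim)   # (B * S, D)
final_output = torch.zeros_like(flat_tokens)     # (B * S, D)
for expert_id, expert in enumerate(experts):
    # Rows (tokens) and columns (top-k slots) routed to this expert
    token_ids, slot_ids = (indices == expert_id).nonzero(as_tuple=True)
    if token_ids.numel() == 0:
        continue  # No tokens routed to this expert in this batch
    expert_out = expert(flat_tokens[token_ids])  # (num_routed, D)
    # Accumulate each expert's output scaled by its routing weight
    final_output[token_ids] += weights[token_ids, slot_ids].unsqueeze(-1) * expert_out
print("Combined output shape:", final_output.view(batch_size, seq_len, model_dim).shape)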
Two further points are worth noting:
- The raw_logits output is crucial for implementing auxiliary loss functions (Chapter 3) designed to prevent expert load imbalance, a common challenge in MoE training; a brief preview sketch appears at the end of this section.
- The choice of k (the number of experts per token) affects computational load and model capacity. k=1 or k=2 are common starting points.

This practical exercise demonstrates how to implement different gating strategies. By modifying these building blocks, you can experiment with various architectural ideas discussed in this chapter, tailoring the routing mechanism to the specific needs of your MoE model. The next chapter will focus on the training dynamics, particularly how to use outputs like raw_logits to ensure stable and balanced training.
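As a brief preview of those auxiliary losses, the sketch below shows one common formulation of a load-balancing term (in the style of the Switch Transformer loss), computed from the raw logits of the standard gate above; treat it as illustrative rather than a finished recipe:

# Illustrative Switch-style load-balancing loss computed from raw logits
probs = F.softmax(logits, dim=-1)              # (B * S, num_experts) router probabilities
top1 = probs.argmax(dim=-1)                    # hard top-1 assignment per token
load = F.one_hot(top1, num_experts).float().mean(dim=0)   # fraction of tokens per expert
importance = probs.mean(dim=0)                 # mean router probability per expert
aux_loss = num_experts * torch.sum(load * importance)     # minimized by a uniform router
print("Load-balancing loss:", aux_loss.item())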