Having explored the theoretical underpinnings of various gating network designs, including top-k routing, noise injection, and architectural variations, we now turn to practical implementation. This section provides hands-on examples using PyTorch to construct custom gating mechanisms. Understanding how to translate these concepts into code is essential for building and experimenting with advanced Mixture of Experts models.We will implement three common types of gating networks:A standard Top-k Gating network using a linear layer and softmax.A Noisy Top-k Gating network, which adds noise to the router logits before selection.A simple Non-Linear Gating network using a small Multi-Layer Perceptron (MLP).These examples assume familiarity with PyTorch fundamentals. We will focus specifically on the gating module itself, showing how it processes input tokens and produces routing decisions (expert indices and weights).SetupFirst, let's import the necessary libraries and define some configuration parameters we'll use throughout the examples.import torch import torch.nn as nn import torch.nn.functional as F import math # --- Configuration --- model_dim = 512 # Dimension of input token representation num_experts = 8 # Total number of experts top_k = 2 # Number of experts to route each token to batch_size = 4 # Example batch size seq_len = 10 # Example sequence length # Example input tensor (Batch Size, Sequence Length, Model Dimension) input_tokens = torch.randn(batch_size, seq_len, model_dim)1. Standard Top-k GatingThis is the most common gating mechanism. It uses a single linear layer to project the input token dimension to the number of experts. Softmax is applied to get probabilities, and torch.topk selects the experts with the highest probabilities.class StandardTopKGating(nn.Module): """ Standard Top-k Gating Network. Uses a linear layer and softmax to compute expert scores, then selects the top-k experts based on these scores. """ def __init__(self, model_dim: int, num_experts: int, top_k: int): super().__init__() self.model_dim = model_dim self.num_experts = num_experts self.top_k = top_k # Linear layer to map token embedding to expert scores self.gate_proj = nn.Linear(self.model_dim, self.num_experts, bias=False) print(f"Initialized StandardTopKGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}") def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass for the gating network. Args: x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension) Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - combined_weights (torch.Tensor): Routing weights for selected experts, shape (Batch Size * Sequence Length, top_k). - expert_indices (torch.Tensor): Indices of selected experts, shape (Batch Size * Sequence Length, top_k). - raw_logits (torch.Tensor): Raw logits output by the linear layer, shape (Batch Size * Sequence Length, num_experts). Useful for auxiliary losses. """ # Reshape input for the linear layer: (B * S, D) original_shape = x.shape x = x.view(-1, self.model_dim) # Flatten batch and sequence dimensions # Project input tokens to expert scores (logits) # Shape: (B * S, num_experts) raw_logits = self.gate_proj(x) # Get top-k scores and indices using torch.topk # top_k_logits shape: (B * S, top_k) # top_k_indices shape: (B * S, top_k) top_k_logits, top_k_indices = torch.topk(raw_logits, self.top_k, dim=-1) # Apply softmax to the selected top-k logits to get weights # Shape: (B * S, top_k) combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float) # Return the weights, indices, and the raw logits (needed for potential auxiliary losses) return combined_weights, top_k_indices, raw_logits # --- Instantiate and Test --- standard_gating = StandardTopKGating(model_dim, num_experts, top_k) weights, indices, logits = standard_gating(input_tokens) print("\n--- Standard Top-k Gating Output ---") print("Input Shape:", input_tokens.shape) print("Combined Weights Shape:", weights.shape) print("Expert Indices Shape:", indices.shape) print("Raw Logits Shape:", logits.shape) # Example output for one token print("Example Weights (Token 0):", weights[0]) print("Example Indices (Token 0):", indices[0])The output combined_weights represents the normalized importance score for each chosen expert for each token. The expert_indices tell us which experts were chosen. The raw_logits are often used in calculating auxiliary load balancing losses, which we will discuss in the next chapter.digraph G { rankdir=LR; node [shape=box, style=rounded, fontname="sans-serif", margin="0.2,0.1"]; edge [fontname="sans-serif", fontsize=10]; Input [label="Input Token\n(Batch*Seq, Dim)", shape= Mrecord, color="#495057", style=filled, fillcolor="#dee2e6"]; Linear [label="Linear Layer\n(Dim -> NumExperts)", color="#4263eb", style=filled, fillcolor="#bac8ff"]; Logits [label="Raw Logits\n(Batch*Seq, NumExperts)", shape= Mrecord, color="#495057", style=filled, fillcolor="#dee2e6"]; TopK [label="Top-K Selection", color="#ae3ec9", style=filled, fillcolor="#eebefa"]; TopKLogits [label="Top-K Logits\n(Batch*Seq, TopK)", shape= Mrecord, color="#495057", style=filled, fillcolor="#dee2e6"]; TopKIndices [label="Top-K Indices\n(Batch*Seq, TopK)", shape= Mrecord, color="#ae3ec9", style=filled, fillcolor="#eebefa"]; Softmax [label="Softmax", color="#1098ad", style=filled, fillcolor="#99e9f2"]; Weights [label="Combined Weights\n(Batch*Seq, TopK)", shape= Mrecord, color="#1098ad", style=filled, fillcolor="#99e9f2"]; Input -> Linear; Linear -> Logits; Logits -> TopK [label=" Select K largest"]; TopK -> TopKLogits; TopK -> TopKIndices; TopKLogits -> Softmax; Softmax -> Weights; { rank=same; Input; } { rank=same; Linear; } { rank=same; Logits; } { rank=same; TopK; } { rank=same; TopKLogits; TopKIndices;} { rank=same; Softmax;} { rank=same; Weights;} }Flow diagram illustrating the Standard Top-k Gating mechanism. Input tokens are projected, top-k logits and indices are selected, and softmax is applied to the selected logits to produce routing weights.2. Noisy Top-k GatingAdding noise to the gating logits before the top-k selection is a technique used to encourage exploration during training and sometimes improve load balancing or robustness. Gaussian noise is commonly used.class NoisyTopKGating(nn.Module): """ Noisy Top-k Gating Network. Adds Gaussian noise to the logits before top-k selection during training. """ def __init__(self, model_dim: int, num_experts: int, top_k: int, noise_stddev=1.0): super().__init__() self.model_dim = model_dim self.num_experts = num_experts self.top_k = top_k self.noise_stddev = noise_stddev self.gate_proj = nn.Linear(self.model_dim, self.num_experts, bias=False) # Layer for adding noise, only applied during training self.noise_layer = nn.Linear(self.model_dim, self.num_experts, bias=False) print(f"Initialized NoisyTopKGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}, NoiseStd={noise_stddev}") def forward(self, x: torch.Tensor, is_training: bool = True) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass for the noisy gating network. Args: x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension) is_training (bool): Flag indicating if the model is in training mode. Noise is only added during training. Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - combined_weights (torch.Tensor): Routing weights for selected experts. - expert_indices (torch.Tensor): Indices of selected experts. - raw_logits (torch.Tensor): Raw logits *before* noise addition. """ original_shape = x.shape x = x.view(-1, self.model_dim) # Get base logits # Shape: (B * S, num_experts) clean_logits = self.gate_proj(x) if is_training: # Calculate noise contribution # We use a separate linear layer for noise magnitude, scaled by standard normal noise # Shape: (B * S, num_experts) noise_magnitude = self.noise_layer(x) # Softplus ensures the magnitude scaling is positive noise_scale = F.softplus(noise_magnitude) # Sample standard Gaussian noise # Shape: (B * S, num_experts) sampled_noise = torch.randn_like(clean_logits) * self.noise_stddev # Add scaled noise to the clean logits noisy_logits = clean_logits + (noise_scale * sampled_noise) else: # No noise during inference noisy_logits = clean_logits # Select top-k based on (potentially noisy) logits # top_k_logits shape: (B * S, top_k) - these are from the *noisy* logits if training # top_k_indices shape: (B * S, top_k) - these are from the *noisy* logits if training top_k_logits, top_k_indices = torch.topk(noisy_logits, self.top_k, dim=-1) # Apply softmax to the selected top-k logits to get weights # Shape: (B * S, top_k) combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float) # Return weights, indices, and the original *clean* logits for auxiliary loss calculation return combined_weights, top_k_indices, clean_logits # Note: Return clean_logits # --- Instantiate and Test --- noisy_gating = NoisyTopKGating(model_dim, num_experts, top_k) # Test in training mode weights_train, indices_train, logits_train = noisy_gating(input_tokens, is_training=True) print("\n--- Noisy Top-k Gating Output (Training) ---") print("Input Shape:", input_tokens.shape) print("Weights Shape (Train):", weights_train.shape) print("Indices Shape (Train):", indices_train.shape) print("Logits Shape (Train - Clean):", logits_train.shape) print("Example Indices (Train - Token 0):", indices_train[0]) # May differ from standard due to noise # Test in inference mode weights_eval, indices_eval, logits_eval = noisy_gating(input_tokens, is_training=False) print("\n--- Noisy Top-k Gating Output (Evaluation) ---") print("Weights Shape (Eval):", weights_eval.shape) print("Indices Shape (Eval):", indices_eval.shape) # Should match standard gating if weights are same print("Example Indices (Eval - Token 0):", indices_eval[0]) # Check if eval indices match standard gating (assuming same weights init) # Note: Due to floating point precision, small differences might occur. # In a real scenario, weight initialization would be controlled. # print("Eval indices match standard?", torch.allclose(indices_eval, indices)) # Needs same weight initNotice that noise is only added during training (is_training=True). During evaluation or inference, the behavior reverts to the standard top-k selection based on the clean logits. It's also important to return the clean logits for potential use in auxiliary loss calculations, as these reflect the router's underlying preference without the stochastic training noise. The specific noise implementation (e.g., using a separate learnable layer noise_layer and softplus for scaling) follows common practices seen in literature like the Switch Transformer.3. Non-Linear Gating (MLP Router)While linear routers are common, sometimes a more expressive router using non-linearities can capture more complex routing patterns. Here's a simple example using a two-layer MLP with a ReLU activation.class NonLinearGating(nn.Module): """ Non-Linear Top-k Gating Network using a simple MLP. """ def __init__(self, model_dim: int, num_experts: int, top_k: int, hidden_dim_multiplier=2): super().__init__() self.model_dim = model_dim self.num_experts = num_experts self.top_k = top_k self.hidden_dim = model_dim * hidden_dim_multiplier # Simple MLP: Linear -> ReLU -> Linear self.mlp = nn.Sequential( nn.Linear(self.model_dim, self.hidden_dim), nn.ReLU(), nn.Linear(self.hidden_dim, self.num_experts, bias=False) ) print(f"Initialized NonLinearGating: Dim={model_dim}, Experts={num_experts}, TopK={top_k}, Hidden={self.hidden_dim}") def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass for the non-linear gating network. Args: x (torch.Tensor): Input tensor of shape (Batch Size, Sequence Length, Model Dimension) Returns: tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - combined_weights (torch.Tensor): Routing weights for selected experts. - expert_indices (torch.Tensor): Indices of selected experts. - raw_logits (torch.Tensor): Raw logits output by the MLP. """ original_shape = x.shape x = x.view(-1, self.model_dim) # Get logits from the MLP # Shape: (B * S, num_experts) raw_logits = self.mlp(x) # Get top-k scores and indices # top_k_logits shape: (B * S, top_k) # top_k_indices shape: (B * S, top_k) top_k_logits, top_k_indices = torch.topk(raw_logits, self.top_k, dim=-1) # Apply softmax to the selected top-k logits # Shape: (B * S, top_k) combined_weights = F.softmax(top_k_logits, dim=-1, dtype=torch.float) # Return weights, indices, and raw logits return combined_weights, top_k_indices, raw_logits # --- Instantiate and Test --- nonlinear_gating = NonLinearGating(model_dim, num_experts, top_k) weights_nl, indices_nl, logits_nl = nonlinear_gating(input_tokens) print("\n--- Non-Linear (MLP) Gating Output ---") print("Input Shape:", input_tokens.shape) print("Weights Shape:", weights_nl.shape) print("Indices Shape:", indices_nl.shape) print("Logits Shape:", logits_nl.shape) print("Example Indices (Token 0):", indices_nl[0])This MLP router introduces more parameters and computation compared to the simple linear router. The choice between linear and non-linear routers depends on the specific task and dataset, often involving empirical validation. More complex router architectures, such as those incorporating attention mechanisms, could also be implemented following similar principles.Using the Gating OutputsThe primary outputs of any gating mechanism are the combined_weights and expert_indices. These are used within the full MoE layer to combine the outputs of the selected experts. While the full MoE layer implementation is not covered in this specific section, the core idea is:Use expert_indices to identify which experts need to process which tokens. This often involves complex dispatching logic in distributed settings (covered in Chapter 4).Pass the relevant tokens through their assigned experts.Use combined_weights to perform a weighted sum of the outputs from the selected experts for each token.For example:# Usage (simplified, assumes expert outputs are gathered) # Assume 'expert_outputs' is a tensor where expert_outputs[i] is the output # for the i-th token from *one* of its assigned experts. # The full implementation requires handling multiple experts per token and gathering results. # Simplified combination for a single token 't' assigned to experts 'e1' and 'e2' # with weights 'w1' and 'w2': # final_output_t = w1 * expert_output_t_e1 + w2 * expert_output_t_e2Router DesignComputational Cost: Non-linear routers are computationally more expensive than linear ones. Noisy routers add minimal overhead during training.Expressiveness vs. Stability: More complex routers might offer better specialization but can sometimes be harder to train stably. Techniques discussed earlier, like router stabilization methods, become important.Integration with Load Balancing: The raw_logits output is important for implementing auxiliary loss functions (Chapter 3) designed to prevent expert load imbalance, a common challenge in MoE training.Top-k Value: The choice of k (number of experts per token) affects computational load and model capacity. k=1 or k=2 are common starting points.This practical exercise demonstrates how to implement different gating strategies. By modifying these building blocks, you can experiment with various architectural ideas discussed in this chapter, tailoring the routing mechanism to the specific needs of your MoE model. The next chapter will focus on the training dynamics, particularly how to use outputs like raw_logits to ensure stable and balanced training.