With the high-level components of gating and expert networks in place, we can now assemble them into a precise mathematical model. This section details the complete forward pass of a single token through a sparse MoE layer as a sequence of concrete computational steps.

The Gating Network: From Input to Routing Weights

The process begins with the gating network, also known as the router. Its job is to determine which experts should process the current input token. The input is a token embedding, represented as a vector $x \in \mathbb{R}^d$, where $d$ is the model's hidden dimension.

The gating network itself is a simple linear layer, defined by a weight matrix $W_g \in \mathbb{R}^{d \times N}$, where $N$ is the total number of experts. This layer projects the input token into a space of dimension $N$, producing a logit for each expert:

$$ h(x) = x \cdot W_g $$

The resulting vector $h(x)$ contains $N$ raw scores. To convert these scores into a probability distribution, we apply the softmax function:

$$ g(x) = \text{softmax}(h(x)) $$

The output, $g(x)$, is a dense $N$-dimensional vector where each element $g(x)_i$ represents the router's confidence in assigning the token to expert $i$. The elements of $g(x)$ sum to 1.

Enforcing Sparsity with Top-K Gating

A dense $g(x)$ vector implies that every expert would contribute to the output, which defeats the computational efficiency goal of MoEs. To enforce sparsity, we apply a Top-K operation: instead of using all experts, we select a small, fixed number $k$ of the highest-scoring experts.

For a given token, we identify the indices of the top $k$ values in $g(x)$ and set all other gating values to zero. This produces a sparse gating vector, $G(x)$. The choice of $k$ is a critical hyperparameter. In Switch Transformers, $k=1$, meaning each token is routed to a single expert. A more common choice is $k=2$, which provides a path for learning more complex functions and adds a degree of redundancy.

This operation effectively prunes the computational graph for each token. If $k=2$ and we have $N=64$ experts, we only need to perform the forward pass for 2 of them, ignoring the other 62.

The Expert Networks

Each of the $N$ experts is typically an independent feed-forward network (FFN). While they all share the same architecture, they do not share weights: each expert $E_i$ has its own set of parameters. A standard two-layer FFN expert can be written as:

$$ E_i(x) = \text{ReLU}(x \cdot W_{1,i}) \cdot W_{2,i} $$

Here, $W_{1,i}$ and $W_{2,i}$ are the weight matrices for the first and second linear layers of expert $i$, respectively. It is this collection of independent expert weights that produces the dramatic increase in the model's total parameter count.

The Complete Forward Pass

We can now combine these steps to define the final output, $y(x)$, of the MoE layer. The output is the weighted sum of the outputs of the selected experts, using the sparse gating weights from the Top-K operation:

$$ y(x) = \sum_{i=1}^{N} G(x)_i \cdot E_i(x) $$

Since $G(x)$ is sparse, with only $k$ non-zero values, this summation is computationally efficient: we only need to evaluate $E_i(x)$ for the $k$ experts selected by the router.
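To make these steps concrete, here is a minimal PyTorch sketch of the forward pass, assuming the formulation above. The class and argument names (SparseMoELayer, Expert, d_hidden, and so on) are illustrative, the per-token Python loop is written for readability rather than the batched per-expert dispatch used in production systems, and the top-$k$ weights from the dense softmax are used directly, without the renormalization discussed later in this section.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """A standard two-layer FFN expert: E_i(x) = ReLU(x W_{1,i}) W_{2,i}."""
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_hidden, bias=False)
        self.w2 = nn.Linear(d_hidden, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.relu(self.w1(x)))


class SparseMoELayer(nn.Module):
    """Router (linear + softmax) -> Top-K selection -> weighted sum of k experts."""
    def __init__(self, d_model: int, d_hidden: int, num_experts: int, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts, bias=False)  # W_g
        self.experts = nn.ModuleList(
            Expert(d_model, d_hidden) for _ in range(num_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_tokens, d_model)
        logits = self.gate(x)                  # h(x), shape (num_tokens, N)
        probs = F.softmax(logits, dim=-1)      # g(x), dense routing weights
        topk_vals, topk_idx = torch.topk(probs, self.top_k, dim=-1)

        outputs = []
        for t in range(x.size(0)):             # per-token loop, for clarity only
            token = x[t]
            # y(x) = sum_i G(x)_i * E_i(x), evaluated only for the k selected experts
            y_t = sum(w * self.experts[int(i)](token)
                      for w, i in zip(topk_vals[t], topk_idx[t]))
            outputs.append(y_t)
        return torch.stack(outputs)


# Example: with 8 experts and k=2, only 2 expert FFNs run per token.
layer = SparseMoELayer(d_model=512, d_hidden=2048, num_experts=8, top_k=2)
tokens = torch.randn(4, 512)
print(layer(tokens).shape)  # torch.Size([4, 512])
```

Production implementations replace the inner loop with a gather/scatter of tokens into per-expert batches, but the algebra is exactly the weighted sum shown here.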
The entire data flow for a single token can be visualized as follows:

```dot
digraph G {
  rankdir=TB;
  splines=ortho;
  node [shape=box, style="rounded,filled", fontname="sans-serif", margin="0.2,0.1"];
  edge [fontname="sans-serif", fontsize=10];

  subgraph cluster_input {
    label="Input Token"; style=invis;
    x [label="Input x", fillcolor="#a5d8ff"];
  }
  subgraph cluster_gating {
    label="Gating Network (Router)"; bgcolor="#ffec99"; style=rounded;
    gating_linear [label="Linear (Wg)", fillcolor="#ffe066"];
    softmax [label="Softmax", fillcolor="#ffe066"];
    topk [label="Top-K Selection\n(k=2)", fillcolor="#ffe066"];
  }
  subgraph cluster_experts {
    label="Expert Networks"; bgcolor="#e9ecef"; style=rounded;
    e1 [label="Expert 1 (FFN)", fillcolor="#dee2e6", style="dashed"];
    e2 [label="Expert 2 (FFN)", fillcolor="#b2f2bb"];
    e_dots [label="...", shape=none];
    eN [label="Expert N (FFN)", fillcolor="#b2f2bb"];
  }
  subgraph cluster_output {
    label="Final Output"; style=invis;
    sum_op [label="Weighted Sum Σ", shape=circle, fillcolor="#ffc9c9"];
    y [label="Output y(x)", fillcolor="#a5d8ff"];
  }

  // Connections
  x -> gating_linear [label="d-dim vector"];
  gating_linear -> softmax [label="N-dim logits"];
  softmax -> topk [label="N-dim weights"];
  topk -> sum_op [label="G(x) (Sparse Weights)", color="#f03e3e", fontcolor="#f03e3e"];
  x -> e1 [style=dashed];
  x -> e2;
  x -> e_dots [style=invis];
  x -> eN;
  e1 -> sum_op [label="E₁(x)", style=dashed];
  e2 -> sum_op [label="E₂(x)"];
  eN -> sum_op [label="Eɴ(x)"];
  sum_op -> y;

  // Invisible edges for alignment
  e_dots -> sum_op [style=invis];

  // Annotations
  {rank=same; e1; e2; e_dots; eN;}
  note [label="In this example, Expert 2 and Expert N\nare selected by the Top-K gating.\nExpert 1 is inactive (dashed lines).", shape=note, fillcolor="#fff9db", style=filled, align=left];
  topk -> note [style=invis];
}
```

Data flow for a single token through an MoE layer. The input x is sent to the gating network to produce sparse weights and, in parallel, to the selected experts for processing.

A Note on Differentiability and Weight Renormalization

The Top-K function is non-differentiable, which poses a problem for backpropagation. In practice, this is handled by a straight-through estimator: during the forward pass, we apply the discrete Top-K selection; during the backward pass, we pass the gradients through the top $k$ gates as if the selection had been a simple multiplication. The dense gating output $g(x)$ is used to compute the gradients for the gating weights $W_g$.

Additionally, after selecting the top $k$ values from the initial softmax output $g(x)$, their sum is no longer guaranteed to be 1. To form a proper convex combination, these $k$ values are often re-normalized. This is typically done by applying a second softmax over only the selected top $k$ logits from $h(x)$, which ensures that the weights used in the final summation reflect their relative importance and sum to 1 (a short sketch of this step appears at the end of this section).

This formulation provides a model with a massive number of parameters but a constant computational cost per token, determined by $k$ rather than $N$. However, this elegant structure introduces a significant challenge: if the gating network learns to route most tokens to only a few experts, the other experts will not receive training signals. This leads to the problem of expert collapse, which we address next by introducing load balancing losses.
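As referenced above, here is a small sketch of the top-$k$ weight renormalization, again in PyTorch under the same assumptions as the earlier example; the helper name renormalize_top_k is illustrative. It applies a softmax over only the selected logits from $h(x)$, so the $k$ routing weights form a proper convex combination.

```python
import torch
import torch.nn.functional as F


def renormalize_top_k(logits: torch.Tensor, k: int = 2):
    """Select the top-k logits per token and softmax over them alone,
    so the k routing weights sum to 1."""
    topk_logits, topk_idx = torch.topk(logits, k, dim=-1)  # (num_tokens, k)
    weights = F.softmax(topk_logits, dim=-1)               # renormalized weights
    return weights, topk_idx


# With N = 4 experts and k = 2, the two weights per token sum to 1.
h = torch.tensor([[2.0, 0.5, 1.5, -1.0]])  # raw router logits h(x) for one token
w, idx = renormalize_top_k(h, k=2)
print(idx)         # tensor([[0, 2]])
print(w.sum(-1))   # sums to 1 (up to floating point)
```

In the earlier SparseMoELayer sketch, these renormalized weights would replace the raw topk_vals taken from the dense softmax.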