While Expert Parallelism (EP) effectively addresses the memory constraints imposed by storing numerous experts within MoE layers by distributing them across multiple devices, it doesn't inherently scale the overall data processing throughput beyond the capacity of the devices holding the experts. For truly massive datasets and accelerated training times, we need to process larger global batch sizes. This necessitates integrating Expert Parallelism with the familiar strategy of Data Parallelism (DP).
Data Parallelism, in its standard form, replicates the entire model across multiple devices. Each device processes a different slice of the input data batch, computes gradients locally, and then synchronizes these gradients (typically via an All-Reduce operation) before updating the model weights. However, as highlighted in the chapter introduction, replicating a large MoE model entirely on each device quickly becomes infeasible due to the sheer number of parameters, even if many are inactive for a given input.
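To make the baseline concrete, here is a minimal sketch of that gradient-synchronization step written directly against torch.distributed; in practice one would use DistributedDataParallel, which buckets and overlaps this communication with the backward pass.

```python
import torch
import torch.distributed as dist

def dp_gradient_sync(model: torch.nn.Module) -> None:
    """Average gradients across all data-parallel replicas.

    Illustrative only: DistributedDataParallel performs the same
    All-Reduce automatically and overlaps it with backprop.
    """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient over every replica, then average.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)
```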
Combining EP and DP provides a powerful two-dimensional scaling approach. Imagine your available processing units (e.g., GPUs) arranged conceptually in a grid.
- Expert Parallelism Dimension: Along one dimension (say, rows of the grid), devices form an EP group. The experts of each MoE layer are sharded across the devices within this group. Communication within this group primarily involves the All-to-All pattern to shuffle tokens to their designated experts during the forward pass and shuffle gradients back during the backward pass.
- Data Parallelism Dimension: Along the other dimension (say, columns of the grid), devices form a DP group. Each device in a DP group holds the shared (non-expert) parameters plus the expert shard assigned by its EP group, so a complete model instance is formed jointly with the rest of its EP group. Each DP replica processes a distinct slice of the global data batch. Communication across DP replicas involves All-Reduce operations to aggregate gradients for the non-expert parameters (attention layers, embeddings, and the gating network weights) as well as for the gradients of the expert shards held within each EP group. A sketch of constructing these process groups follows this list.
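As promised above, here is a minimal sketch of carving such a grid into process groups with torch.distributed. The row-major layout (rows as EP groups, columns as DP groups) and the helper name are illustrative assumptions, not a fixed convention.

```python
import torch.distributed as dist

def build_ep_dp_groups(n_dp: int, n_ep: int):
    """Create EP groups (rows) and DP groups (columns) of an n_dp x n_ep grid.

    Every rank must call dist.new_group() for every group, including the
    ones it does not belong to, so that all ranks issue the same sequence
    of collective calls. Returns the EP and DP groups of the calling rank.
    """
    rank = dist.get_rank()
    assert dist.get_world_size() == n_dp * n_ep

    my_ep_group, my_dp_group = None, None

    # Rows: the Nep devices that jointly shard one replica's experts
    # (All-to-All runs inside these groups).
    for row in range(n_dp):
        ranks = [row * n_ep + col for col in range(n_ep)]
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_ep_group = group

    # Columns: the Ndp devices holding the same expert shard in different
    # replicas (All-Reduce runs inside these groups).
    for col in range(n_ep):
        ranks = [row * n_ep + col for row in range(n_dp)]
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            my_dp_group = group

    return my_ep_group, my_dp_group
```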
How it Works in Practice
Consider a training setup with N total devices, arranged as Ndp data-parallel replicas, each spanning Nep devices for Expert Parallelism (N = Ndp × Nep).
- Data Distribution: The global input batch is split into Ndp chunks. Each replica receives one chunk and divides it into micro-batches across its Nep devices.
- Forward Pass (within a replica):
- Each device processes its micro-batch through the initial non-MoE layers of the model.
- When an MoE layer is encountered, the gating network (replicated across all devices) determines the target expert(s) for each token.
- An All-to-All communication operation occurs within the EP group (i.e., among the Nep devices) to send each token representation to the device holding its assigned expert.
- Each device computes the output for the tokens it received using its local expert(s).
- Another All-to-All communication occurs within the EP group to send the expert outputs back to the originating devices based on the initial token distribution.
- Tokens proceed through the rest of the model layers on their original device within the replica.
- Backward Pass:
- Gradients flow back through the network.
- For MoE layers, gradients need to be routed back to the experts that computed the forward pass. This involves another All-to-All within the EP group for gradient shuffling.
- Gradients for the expert parameters are computed locally on the device holding the expert.
- Gradients for the shared parameters (non-expert layers and gating networks) are computed on all devices.
- Gradient Aggregation:
- An All-Reduce operation occurs across each DP group (i.e., among the Ndp devices with the same EP rank in different replicas) to average the gradients of the shared parameters. In practice, because each device within a replica also processed different tokens, many implementations average shared-parameter gradients over all N devices rather than only along the DP dimension.
- Crucially, an All-Reduce also occurs across the DP group for the expert-parameter gradients. Although each expert shard exists on only one device within a replica, the corresponding shard on every other replica processed different tokens, so their gradients must be averaged across the Ndp replicas.
- Optimizer Step: Each device updates its local copy of the shared parameters and its local shard of the expert parameters using the aggregated gradients. A minimal sketch of the All-to-All dispatch and the DP-group gradient averaging follows this list.
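The two MoE-specific pieces of this sequence are the All-to-All dispatch/combine inside an EP group and the gradient averaging across the DP group. The sketch below illustrates both under simplifying assumptions (fixed capacity per destination, one expert per device, d_model preserved by the expert); the function names are hypothetical and no particular framework's API is implied.

```python
import torch
import torch.distributed as dist

def moe_dispatch_compute_combine(tokens_by_dest: torch.Tensor,
                                 local_expert: torch.nn.Module,
                                 ep_group) -> torch.Tensor:
    """Capacity-based dispatch within one EP group.

    `tokens_by_dest` has shape (n_ep * capacity, d_model): slice i holds the
    (padded) tokens this rank routes to the expert on EP-rank i. Equal-sized
    slices keep the All-to-All splits uniform.
    """
    # First All-to-All: every rank receives the tokens assigned to its expert.
    received = torch.empty_like(tokens_by_dest)
    dist.all_to_all_single(received, tokens_by_dest, group=ep_group)

    # Local expert computation on the tokens this rank now owns.
    expert_out = local_expert(received)

    # Second All-to-All: return expert outputs to the ranks that sent them.
    combined = torch.empty_like(expert_out)
    dist.all_to_all_single(combined, expert_out, group=ep_group)
    return combined


def aggregate_gradients(model: torch.nn.Module, dp_group) -> None:
    """Average gradients across the Ndp replicas after the backward pass.

    In this simplified layout both the shared parameters and the local
    expert shard are reduced over the same DP group; frameworks usually
    split them into separate buckets and process groups for overlap.
    """
    dp_size = dist.get_world_size(group=dp_group)
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=dp_group)
            param.grad.div_(dp_size)
```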
Visualizing the Configuration
We can represent this 2D parallelism using a device mesh. For example, with 8 GPUs, we could have an Ndp=2, Nep=4 configuration.
A 2×4 device configuration: GPUs 0-3 form one DP replica using 4-way EP; GPUs 4-7 form a second DP replica, also using 4-way EP. All-to-All communication happens within rows (EP groups); All-Reduce communication happens across columns (DP groups), connecting devices with the same EP rank.
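Recent PyTorch releases (2.2 and later) expose a DeviceMesh abstraction that makes this grid explicit. A minimal sketch for the 2×4 layout above, with illustrative dimension names:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Build a 2 x 4 logical grid over 8 GPUs. With this shape the "ep" dimension
# groups ranks 0-3 and 4-7 (the rows of the figure), while the "dp" dimension
# groups ranks that share an EP rank, e.g. {0, 4} (the columns).
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "ep"))

# Process-group handles for the collectives: All-to-All runs over the EP
# group, All-Reduce over the DP group.
ep_group = mesh.get_group("ep")
dp_group = mesh.get_group("dp")

rank = dist.get_rank()
print(f"rank {rank}: dp index {mesh.get_local_rank('dp')}, "
      f"ep index {mesh.get_local_rank('ep')}")
```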
Communication and Trade-offs
This combined approach introduces both All-to-All (within EP groups) and All-Reduce (across DP groups) communication collectives. The efficiency depends heavily on the underlying hardware interconnect topology and the relative costs of these operations.
- Increasing Nep (wider EP groups): Allows for more experts or larger experts per layer, scaling model size. However, it adds participants to each All-to-All and increases the fraction of tokens that must cross the interconnect, raising its latency.
- Increasing Ndp (more DP replicas): Increases the global batch size that can be processed, improving throughput. However, it adds participants to the All-Reduce operations and replicates the shared parameters (and each expert shard) across more devices, increasing aggregate memory use. A rough per-device traffic estimate follows this list.
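As a rough way to reason about these trade-offs, the back-of-envelope estimator below shows how per-device traffic scales with Nep and Ndp; the assumptions (top-1 routing, a single MoE layer, ring All-Reduce, no overlap or padding accounted for) are illustrative only.

```python
def per_device_comm_bytes(tokens_per_device: int, d_model: int,
                          n_ep: int, n_dp: int,
                          expert_params: int, shared_params: int,
                          bytes_per_elem: int = 2) -> dict:
    """Very rough per-device, per-step traffic estimate in bytes.

    Assumes top-1 routing, one MoE layer, a ring All-Reduce, and ignores
    capacity-factor padding, gating metadata, and communication/compute
    overlap; intended only to show the scaling with n_ep and n_dp.
    """
    token_bytes = tokens_per_device * d_model * bytes_per_elem
    # Dispatch + combine in the forward pass and again for gradients in the
    # backward pass: four All-to-Alls, each sending a (n_ep - 1)/n_ep
    # fraction of the payload off-device.
    all_to_all = 4 * token_bytes * (n_ep - 1) / n_ep
    # A ring All-Reduce moves roughly 2 * (n_dp - 1)/n_dp of the gradient
    # payload per device.
    grad_bytes = (expert_params + shared_params) * bytes_per_elem
    all_reduce = 2 * grad_bytes * (n_dp - 1) / n_dp
    return {"all_to_all_bytes": all_to_all, "all_reduce_bytes": all_reduce}
```

Larger Nep pushes the All-to-All term up (and adds participants per exchange), while larger Ndp mostly affects the latency and synchronization cost of the All-Reduce, since its per-device ring volume saturates quickly.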
Choosing the optimal (Ndp, Nep) configuration requires careful consideration of the specific model architecture (number of experts, size of experts vs. shared layers), the available hardware (number of devices, interconnect bandwidth/latency), and the desired global batch size. Frameworks like DeepSpeed often provide tools and abstractions to manage these process groups and communication patterns, simplifying the configuration for the user. However, understanding the underlying mechanics is essential for performance tuning and troubleshooting complex distributed MoE training runs. This integration forms the foundation for scaling MoEs to hundreds or thousands of processors, enabling the training of models with trillions of parameters.