As mentioned in the chapter introduction, the scale of modern machine learning often outstrips the capabilities of a single compute device. Whether dealing with enormous datasets or models containing billions of parameters, distributing the computational load becomes a necessity. This section introduces the fundamental concepts of parallelism employed to tackle these large-scale problems, providing the conceptual groundwork for understanding JAX's distributed computing features, particularly pmap.
Parallelism, in essence, involves dividing a computational task into smaller, independent or semi-independent parts that can be executed simultaneously on multiple processing units. In the context of machine learning training and inference, several distinct strategies have emerged, primarily driven by whether the bottleneck is the amount of data or the size of the model itself.
Data parallelism is arguably the most common strategy for accelerating model training. The core idea is simple: if you have a large dataset, split it into smaller chunks and process each chunk concurrently on a separate replica of the model. Each replica computes gradients on its local chunk, and these per-device gradients are then combined across all devices, typically by averaging (pmean) or summing (psum). The aggregated gradient represents the gradient computed over the entire global batch.

This approach effectively increases the total batch size that can be processed per step, often leading to faster convergence and better utilization of multiple accelerators. It works best when the model fits comfortably into the memory of a single device but the dataset is very large. JAX's pmap function is primarily designed to implement this Single-Program Multiple-Data (SPMD) style of data parallelism.
Flow of data parallelism. Data is sharded, processed in parallel on model replicas, gradients are aggregated, and parameters are updated synchronously.
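The sketch below illustrates this pattern with pmap. The linear model, loss function, and learning rate are hypothetical, chosen only for brevity; the point is the structure of the step function: compute gradients on the local shard, average them across devices with pmean over a named axis, then apply a synchronized update.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Toy loss used only to illustrate the pattern; any model would work here.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# One SPMD training step: compute gradients on the local data shard, then
# average them across all devices participating in the "devices" axis.
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    grads = lax.pmean(grads, axis_name="devices")
    # Plain SGD update; every device applies the same aggregated gradient,
    # so the parameter replicas stay synchronized.
    return jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}

# Replicate the parameters onto every device and shard a global batch so each
# leading slice (one per device) holds that device's portion of the data.
replicated_params = jax.tree_util.tree_map(
    lambda a: jnp.stack([a] * n_dev), params)
x_shards = jnp.ones((n_dev, 8, 4))   # per-device batch of 8 examples
y_shards = jnp.ones((n_dev, 8, 1))

replicated_params = p_train_step(replicated_params, x_shards, y_shards)
```

Note that the parameters are replicated (stacked along a leading device axis) while the data is sharded; both therefore carry a leading dimension equal to the device count, which pmap maps over.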
When a model becomes so large that its parameters, activations, or intermediate states cannot fit into the memory of a single accelerator, data parallelism alone is insufficient. Model parallelism addresses this by partitioning the model itself across multiple devices.
This strategy allows training extremely large models but introduces complexities in partitioning the model effectively and managing the communication overhead between devices. Common approaches include tensor parallelism (splitting individual weight matrices) and pipeline parallelism.
Flow of model parallelism. A single data sample flows through model parts distributed across multiple devices.
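As a rough illustration of the tensor-parallel idea, the sketch below splits a weight matrix column-wise across two devices and combines the partial results. The shapes and the explicit device_put placements are assumptions made for illustration only; a real tensor-parallel implementation keeps the combine step on-device as a collective operation rather than funneling results through one device.

```python
import jax
import jax.numpy as jnp

# Assumes at least two devices; with a single device both halves simply land
# on the same one and the example still runs.
devices = jax.devices()
dev_a, dev_b = devices[0], devices[-1]

x = jnp.ones((8, 256))        # activations, small enough to replicate
w = jnp.ones((256, 512))      # weight matrix to be split column-wise

# Each device holds only half of the columns of w and computes its half of
# the matrix multiplication locally.
w_a = jax.device_put(w[:, :256], dev_a)
w_b = jax.device_put(w[:, 256:], dev_b)

y_a = jax.device_put(x, dev_a) @ w_a
y_b = jax.device_put(x, dev_b) @ w_b

# Combining the partial results reproduces the full x @ w. Here the second
# half is transferred to dev_a before concatenation; in practice this combine
# is a cross-device collective (e.g. an all-gather).
y = jnp.concatenate([y_a, jax.device_put(y_b, dev_a)], axis=1)
assert y.shape == (8, 512)
```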
Pipeline parallelism is a more sophisticated form of model parallelism designed to improve device utilization. Instead of processing a single batch sequentially through model parts spread across devices (leaving some devices idle while others work), pipeline parallelism divides the batch into smaller micro-batches.
This approach helps mitigate the "bubble" of idle time inherent in naive model parallelism but requires careful management of dependencies, scheduling, and state (like activations needed for backward passes).
Illustration of pipeline parallelism across three devices/stages over time steps (T1-T5). Micro-batches (MB1, MB2, MB3) enter the pipeline sequentially, allowing devices to operate concurrently on different micro-batches.
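A small, purely conceptual sketch of this schedule is shown below. It runs on a single host, and the stage function is a placeholder standing in for "the part of the model owned by that stage"; it only demonstrates how micro-batches advance through stages over time steps, mirroring the figure above.

```python
import jax.numpy as jnp

NUM_STAGES = 3        # the model is split into three sequential parts
NUM_MICROBATCHES = 3  # the global batch is split into three micro-batches

batch = jnp.ones((NUM_MICROBATCHES * 8, 4))
microbatches = list(jnp.split(batch, NUM_MICROBATCHES))

# Placeholder computation for stage s; in a real pipeline each stage would
# live on its own device and hold its own parameters.
def stage_fn(s, x):
    return x + s

# At time step t, stage s works on micro-batch (t - s) if it exists. Printing
# the schedule reproduces the fill / steady-state / drain pattern in the
# figure, including the idle "bubble" at the start and end.
for t in range(NUM_STAGES + NUM_MICROBATCHES - 1):
    active = []
    for s in range(NUM_STAGES):
        mb = t - s
        if 0 <= mb < NUM_MICROBATCHES:
            microbatches[mb] = stage_fn(s, microbatches[mb])
            active.append(f"stage {s} -> MB{mb + 1}")
    print(f"T{t + 1}: " + ", ".join(active))
```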
While all these parallelism strategies are relevant in modern machine learning, this chapter will primarily focus on data parallelism. JAX provides powerful tools for this through its pmap transformation, which maps a function across multiple devices, automatically handling data distribution and providing mechanisms for collective communication (such as gradient aggregation). Understanding data parallelism and pmap is fundamental for scaling most standard training workloads in JAX. We will delve into device management, the SPMD execution model, collective operations, and practical implementation details in the following sections. A conceptual understanding of model and pipeline parallelism provides valuable context, especially for the extremely large models discussed later in Chapter 6.