Training contemporary machine learning models often pushes past the memory and compute limits of a single accelerator. This chapter focuses on the techniques needed to scale JAX applications effectively.
We will examine how to structure large models using libraries like Flax and Haiku, and how to integrate pmap for distributed data parallelism within these frameworks. You will learn practical strategies such as gradient accumulation to simulate larger effective batch sizes, gradient checkpointing (jax.checkpoint) to reduce memory usage at the cost of recomputation, and mixed precision training for further memory savings and potential speedups. We will also introduce concepts related to model parallelism and discuss optimizers suited for distributed settings.
By the end of this chapter, you will understand how to combine these different JAX features and ecosystem tools to address the challenges of training large neural networks.
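As a brief preview of the rematerialization idea developed in this chapter, the following sketch wraps a small two-layer block in jax.checkpoint so that its intermediate activations are recomputed during the backward pass rather than stored. The block, layer sizes, and squared-error loss are placeholder choices for illustration only, not a prescribed setup.

```python
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    # A simple two-layer block whose intermediate activations would
    # normally be kept in memory for the backward pass.
    w1, w2 = params
    h = jnp.tanh(x @ w1)
    return jnp.tanh(h @ w2)


# jax.checkpoint (also available as jax.remat) tells JAX to discard the
# block's intermediates and recompute them when gradients are taken,
# trading extra compute for lower peak memory.
checkpointed_block = jax.checkpoint(mlp_block)


def loss_fn(params, x, y):
    # Illustrative squared-error loss over the checkpointed block.
    preds = checkpointed_block(params, x)
    return jnp.mean((preds - y) ** 2)


# Placeholder shapes and random data, used only to make the sketch runnable.
key = jax.random.PRNGKey(0)
k1, k2, kx, ky = jax.random.split(key, 4)
params = (jax.random.normal(k1, (128, 256)), jax.random.normal(k2, (256, 128)))
x = jax.random.normal(kx, (32, 128))
y = jax.random.normal(ky, (32, 128))

# Gradients are computed with rematerialization of mlp_block's activations.
grads = jax.grad(loss_fn)(params, x, y)
```

Section 6.10 works through this technique in more detail, including how to decide which parts of a model are worth rematerializing.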
6.1 Overview of Challenges in Large Model Training
6.2 Introduction to JAX Ecosystem Libraries (Flax, Haiku)
6.3 Managing Model Parameters and State
6.4 Combining pmap with Training Frameworks
6.5 Gradient Accumulation
6.6 Gradient Checkpointing (Re-materialization)
6.7 Mixed Precision Training
6.8 Conceptual Model Parallelism Strategies
6.9 Optimization Algorithms for Large Scale
6.10 Practice: Implementing Gradient Checkpointing