JAX's design, centered around pure functions and transformations like jit and grad, calls for specific approaches when a computation inherently involves state. Standard Python patterns often modify objects in place, which conflicts with JAX's functional model, especially when code is compiled with jit or differentiated with grad. How can we update model parameters during training, manage optimizer state, or handle recurrent computations within these constraints?
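The mismatch is easy to observe. The minimal sketch below contrasts an impure counter with a pure one; the names count, impure_increment, and pure_increment are hypothetical and chosen only for illustration. The impure version mutates a Python global. That statement runs only while jit traces the function, so later calls, which reuse the cached compiled version, never execute it. The pure version receives its state as an argument and returns the updated state, so every call does what it says.

```python
import jax
import jax.numpy as jnp

# Impure version: mutates a Python global as a side effect (hypothetical example).
count = 0

@jax.jit
def impure_increment(x):
    global count
    count += 1  # executes only while JAX traces the function
    return x + 1

for _ in range(3):
    impure_increment(jnp.int32(0))
print(count)  # 1, not 3: later calls reuse the compiled trace

# Pure version: the state goes in as an argument and comes back as output.
@jax.jit
def pure_increment(state):
    return state + 1

state = jnp.int32(0)
for _ in range(3):
    state = pure_increment(state)  # caller threads the new state forward
print(int(state))  # 3
```

The caller, not the function, is responsible for keeping the latest state. This is the explicit state-passing pattern the rest of the chapter builds on.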
This chapter introduces functional programming patterns for managing state in JAX. You will learn what functional purity means and why it matters for JAX transformations. We will cover the primary technique for handling state: passing it explicitly as an input to a function and receiving the updated state as an output. You'll see how JAX's PyTree utilities simplify working with complex, nested state structures (such as dictionaries or lists of parameters). We'll apply these ideas through practical examples, such as implementing a stateful counter and managing the state of a simple optimization algorithm, and make sure these patterns work correctly when combined with jit, grad, and other transformations.
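As a preview of how these pieces fit together, here is a minimal sketch of explicit state passing with PyTree-structured state under jit and grad. It is an illustration under stated assumptions, not the chapter's own implementation. The linear-regression loss, the helper names init_state and train_step, and the hyperparameters are hypothetical. The optimizer is plain SGD with heavy-ball momentum, chosen only as a familiar example of an algorithm that carries state between steps. Both the parameters and the optimizer's velocity are dictionaries, and jax.tree_util.tree_map applies each update leaf by leaf.

```python
import jax
import jax.numpy as jnp

# Hypothetical hyperparameters for SGD with heavy-ball momentum.
LEARNING_RATE = 0.1
MOMENTUM = 0.9

def loss_fn(params, x, y):
    # Hypothetical toy model: linear regression with parameters in a dict PyTree.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

def init_state(params):
    # Optimizer state mirrors the parameter PyTree: one zero velocity per leaf.
    return jax.tree_util.tree_map(jnp.zeros_like, params)

@jax.jit
def train_step(params, velocity, x, y):
    # Pure function: old state in, new state out, nothing mutated.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_velocity = jax.tree_util.tree_map(
        lambda v, g: MOMENTUM * v + g, velocity, grads)
    new_params = jax.tree_util.tree_map(
        lambda p, v: p - LEARNING_RATE * v, params, new_velocity)
    return new_params, new_velocity, loss

# Synthetic, noise-free data generated from y = 2*x0 - 3*x1 + 1.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 2))
y = x @ jnp.array([2.0, -3.0]) + 1.0

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
velocity = init_state(params)
for step in range(200):
    # The training loop threads both pieces of state through every call.
    params, velocity, loss = train_step(params, velocity, x, y)
print(float(loss), params["w"], float(params["b"]))
```

Because train_step never touches anything outside its arguments, jit can compile it once and reuse the result on every iteration, and grad sees the parameters only as ordinary inputs. Swapping in a different optimizer changes the shape of the velocity PyTree but not the calling pattern.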
6.1 Functional Purity and Side Effects
6.2 The Challenge of State in Functional Code
6.3 Pattern: Explicit State Passing
6.4 Using PyTrees for Structured State
6.5 Example: Stateful Counter
6.6 Example: Simple Optimizer State
6.7 Combining State Management with Transformations
6.8 Practice: Implementing Stateful Functions
© 2025 ApX Machine Learning