JAX thrives on pure functions. A pure function, given the same inputs, always produces the same outputs and has no side effects. Side effects include modifying global variables, printing to the console, or, most relevantly, changing the state of objects outside the function's scope (like updating an attribute of a class instance).
However, many useful computations, especially in machine learning, involve state that changes over time. Think of model parameters being updated during training, momentum values in an optimizer, or the hidden state in a recurrent neural network. How can we reconcile the need for changing state with JAX's requirement for pure functions?
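The most fundamental pattern in JAX for handling state is explicit state passing. Instead of modifying state in place, functions are written to:

1. Accept the current state as an explicit input argument.
2. Return the updated state as an explicit output, along with any other results the computation produces.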
This pattern treats state just like any other data flowing through your functions. It avoids side effects because the function doesn't modify the original state object; it produces a new one.
Let's consider a very simple example: a counter. In typical imperative Python, you might implement this using a class:
```python
# Imperative (stateful object) approach - Not JAX friendly
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

# Usage
counter_obj = Counter()
print(counter_obj.increment())  # Output: 1
print(counter_obj.increment())  # Output: 2
print(counter_obj.count)        # Output: 2
```
This increment method modifies the object's internal self.count, which is a side effect. If you applied jax.jit or jax.grad to such a method directly, it would not behave as expected. JAX transformations trace the operations a function performs on its inputs, and modifications to internal or external object state are invisible to that tracing process. Under jax.jit, for instance, the self.count += 1 line runs only once, as ordinary Python during tracing, and the value it produces is baked into the compiled code as a constant.
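To make this failure concrete, here is a small sketch (not part of the original example) that applies jax.jit directly to the bound method of the same Counter class. The increment runs as Python only while JAX traces the method; later calls reuse the cached compiled code, which simply returns the recorded constant.

```python
import jax

class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1  # side effect: only happens while JAX is tracing
        return self.count

counter_obj = Counter()
jitted_increment = jax.jit(counter_obj.increment)

first = jitted_increment()   # traces the method: self.count goes from 0 to 1
second = jitted_increment()  # cache hit: the Python body does not run again

print(int(first), int(second), counter_obj.count)
```

Both calls return 1 and the object's count stays at 1, so the "counter" silently stops counting. Nothing raises an error, which is what makes this kind of bug easy to miss.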
Now, let's implement the counter using the explicit state passing pattern suitable for JAX:
```python
import jax

# Functional (explicit state passing) approach - JAX friendly
def init_counter_state():
    """Initializes the state."""
    return 0  # The state is just an integer

def increment(current_state):
    """Takes the current state, returns the new state."""
    print(f"Running increment function with state: {current_state}")  # For demonstration
    new_state = current_state + 1
    # In more complex cases, we might return (new_state, result)
    return new_state

# Usage
state = init_counter_state()
print(f"Initial state: {state}")

# Call 1
state = increment(state)
print(f"State after call 1: {state}")

# Call 2
state = increment(state)
print(f"State after call 2: {state}")
```
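Notice the main differences:

1. The state (previously self.count) is explicitly passed into the increment function.
2. The increment function performs the calculation (current_state + 1) without touching anything outside its scope.
3. The function returns the new state, and the caller rebinds state to that result before the next call.

Apart from the print call, which is kept only to show when the Python code runs, the increment function is pure. Given the input 0, it always returns 1; given 1, it always returns 2. It doesn't modify anything outside its local scope.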
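The comment in the counter code mentions returning (new_state, result) in more complex cases. As a minimal sketch of that variant, the hypothetical observe function below carries a (sum, count) pair as its state and also returns the running mean as a separate result. The function names and the running-mean task are illustrative assumptions, not part of the original example.

```python
import jax
import jax.numpy as jnp

def init_running_mean_state():
    # Hypothetical state: (sum of observations, number of observations)
    return (jnp.array(0.0), jnp.array(0))

@jax.jit
def observe(state, x):
    """Pure update: returns (new_state, result) instead of mutating anything."""
    total, n = state
    new_total = total + x
    new_n = n + 1
    result = new_total / new_n  # running mean after this observation
    return (new_total, new_n), result

state = init_running_mean_state()
means = []
for x in [2.0, 4.0, 9.0]:
    state, mean = observe(state, x)  # the caller threads the state forward
    means.append(float(mean))

print(means)
```

The caller unpacks the tuple on every call, keeping the new state for the next step and using the result however it likes. The function itself stays pure, so it can be jitted without surprises.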
This explicit state passing pattern aligns perfectly with JAX's functional nature and makes stateful computations compatible with transformations like jax.jit:
```python
# Apply JIT to the functional counter
jitted_increment = jax.jit(increment)

state = init_counter_state()
print("\nJIT Compiling and Running:")

# First call: triggers tracing and compilation (the Python print runs during tracing)
state = jitted_increment(state)
print(f"State after JIT call 1: {state}")

# Second call: uses the cached compiled version (the Python print inside won't execute)
state = jitted_increment(state)
print(f"State after JIT call 2: {state}")

# Third call: still uses the compiled version
state = jitted_increment(state)
print(f"State after JIT call 3: {state}")
```
When jitted_increment is called the first time with a state of a specific type and shape (here, a scalar integer), JAX traces the increment function, compiles it with XLA, and executes it. The print statement inside increment runs only during this initial trace, and it shows an abstract tracer object rather than the concrete value 0, because during tracing JAX records operations on placeholders instead of real numbers. On subsequent calls with compatible inputs (same shape and dtype), JAX reuses the cached, optimized compiled code and skips the Python execution entirely, including the print. The state update still happens correctly inside the compiled computation, because the flow of state (in -> out) was part of the traced logic.
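If you want a message on every call, including calls served from the compiled cache, jax.debug.print is staged into the computation instead of running only at trace time. Here is a sketch, where increment_debug is a hypothetical variant of increment that uses it in place of the Python print:

```python
import jax

def increment_debug(current_state):
    # Unlike Python's print, jax.debug.print becomes part of the compiled
    # computation, so it fires on every call and shows the runtime value.
    jax.debug.print("Running increment with state: {}", current_state)
    return current_state + 1

jitted_increment_debug = jax.jit(increment_debug)

state = 0
for _ in range(3):
    state = jitted_increment_debug(state)

print(f"Final state: {state}")
```

Each call now reports the concrete state it received (0, 1, then 2), while the state itself is still threaded explicitly through the function's input and output.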
Figure: Explicit state flow. A pure function takes the current state and any other inputs, performs computations, and returns the updated state along with any other results.
This pattern extends to jax.grad as well. If the state update involves differentiable operations, JAX can differentiate the function with respect to its inputs (including the state, if needed) because the entire data flow is explicit.
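A common way to combine gradients with explicit state is jax.grad(..., has_aux=True). The function returns (loss, new_state); JAX differentiates only the first output and passes the second through unchanged. In the sketch below, the linear model, the running-mean state, and all names are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp

def loss_and_new_state(params, state, x, y):
    # Hypothetical state: an exponential moving average of the input mean
    new_state = 0.9 * state + 0.1 * jnp.mean(x)
    # Squared-error loss for a simple linear model
    pred = params["w"] * x + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    return loss, new_state  # new_state is the auxiliary output

# Differentiate with respect to params (argument 0 by default);
# has_aux=True makes grad return (gradients, auxiliary_output)
grad_fn = jax.grad(loss_and_new_state, has_aux=True)

params = {"w": 1.0, "b": 0.0}
state = 0.0
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])

grads, state = grad_fn(params, state, x, y)
print(grads, state)
```

In a training loop you would apply grads to params and keep the returned state for the next step, exactly as with the counter.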
"While our counter example used a simple integer state, applications often involve complex state structures, like nested dictionaries containing model parameters, optimizer statistics (means, variances), etc. JAX provides tools like PyTrees, discussed next, to make managing these structured states using the explicit passing pattern convenient."
Explicit state passing is the foundation for managing stateful computations within JAX's functional framework. By making the state flow explicit, we retain the ability to use JAX's powerful function transformations for acceleration and differentiation.
© 2026 ApX Machine Learning