As introduced earlier, JAX thrives on pure functions. A pure function, given the same inputs, always produces the same outputs and has no side effects. Side effects include modifying global variables, printing to the console, or, most relevantly here, changing the state of objects outside the function's scope (like updating an attribute of a class instance).
However, many useful computations, especially in machine learning, involve state that changes over time. Think of model parameters being updated during training, momentum values in an optimizer, or the hidden state in a recurrent neural network. How can we reconcile the need for changing state with JAX's requirement for pure functions?
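The most fundamental pattern in JAX for handling state is explicit state passing. Instead of modifying state in place, functions are written to:
1. Accept the current state as an input argument, alongside any other inputs.
2. Compute the new state (and any other results) from those inputs.
3. Return the new state as part of the output, leaving the input state untouched.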
This pattern treats state just like any other data flowing through your functions. It avoids side effects because the function doesn't modify the original state object; it produces a new one.
Let's consider a very simple example: a counter. In typical imperative Python, you might implement this using a class:
# Imperative (stateful object) approach - Not JAX friendly
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

# Usage
counter_obj = Counter()
print(counter_obj.increment())  # Output: 1
print(counter_obj.increment())  # Output: 2
print(counter_obj.count)        # Output: 2
This increment method modifies the object's internal self.count. This is a side effect. If you tried to use jax.jit or jax.grad on such a method directly, it wouldn't work as expected because JAX transformations trace a function's operations based on its inputs, and modifications to external or internal object state are opaque to this tracing process.
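To make the failure mode concrete, here is a minimal sketch of what can go wrong, assuming a recent JAX version. The names impure_increment and counter_obj are illustrative only. The attribute update is ordinary Python, so it executes while JAX traces the function and not when the compiled code runs; the value of count at trace time is baked into the compiled computation as a constant.

```python
import jax
import jax.numpy as jnp

class Counter:
    def __init__(self):
        self.count = 0

counter_obj = Counter()

@jax.jit
def impure_increment(x):
    # Python-level side effect: runs during tracing, not in the compiled code.
    counter_obj.count += 1
    # counter_obj.count is a plain Python int here, so it is captured as a constant.
    return x + counter_obj.count

x = jnp.zeros(())  # float32 scalar; same shape and dtype on every call
results = [float(impure_increment(x)) for _ in range(3)]

print(results)            # every call returns the same value
print(counter_obj.count)  # the attribute was only bumped during the single trace
```

Three calls leave the counter at 1 instead of 3, which is the kind of silent mismatch that explicit state passing is meant to avoid.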
Now, let's implement the counter using the explicit state passing pattern suitable for JAX:
import jax

# Functional (explicit state passing) approach - JAX friendly
def init_counter_state():
    """Initializes the state."""
    return 0  # The state is just an integer

def increment(current_state):
    """Takes the current state, returns the new state."""
    print(f"Running increment function with state: {current_state}")  # For demonstration
    new_state = current_state + 1
    # In more complex cases, we might return (new_state, result)
    return new_state

# Usage
state = init_counter_state()
print(f"Initial state: {state}")

# Call 1
state = increment(state)
print(f"State after call 1: {state}")

# Call 2
state = increment(state)
print(f"State after call 2: {state}")
Notice the key differences:
1. The state (count) is explicitly passed into the increment function.
2. The increment function performs the calculation (current_state + 1).
3. The new state is returned to the caller, which rebinds state and passes it into the next call.
Aside from the demonstration print, the increment function itself is pure. Given the input 0, it always returns 1. Given 1, it always returns 2. It doesn't modify anything outside its local scope.
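The comment in the code above mentions returning (new_state, result) in more complex cases. As a hedged illustration of that shape, the sketch below tracks a running mean in plain Python. The names init_mean_state and update_mean, and the choice of a (count, total) tuple as the state, are assumptions made for this example.

```python
def init_mean_state():
    """State is a (count, total) pair; both start at zero."""
    return (0, 0.0)

def update_mean(state, x):
    """Takes the current state and a new sample; returns (new_state, result)."""
    count, total = state
    new_state = (count + 1, total + x)   # a fresh tuple; the input is untouched
    mean = new_state[1] / new_state[0]   # the per-call result
    return new_state, mean

state = init_mean_state()
means = []
for x in [2.0, 4.0, 6.0]:
    state, mean = update_mean(state, x)
    means.append(mean)

print(state)  # (3, 12.0)
print(means)  # [2.0, 3.0, 4.0]
```

The caller threads the state through each call and is free to use or ignore the per-call result.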
This explicit state passing pattern aligns with JAX's functional nature and makes stateful computations compatible with transformations like jax.jit:
# Apply JIT to the functional counter
jitted_increment = jax.jit(increment)
state = init_counter_state()
print(f"\nJIT Compiling and Running:")
# First call: Triggers JIT compilation (and runs the Python code)
state = jitted_increment(state)
print(f"State after JIT call 1: {state}")
# Second call: Uses the compiled version (Python print inside won't execute)
state = jitted_increment(state)
print(f"State after JIT call 2: {state}")
# Third call: Still uses the compiled version
state = jitted_increment(state)
print(f"State after JIT call 3: {state}")
When jitted_increment is called the first time with a state of a specific type and shape (here, a scalar integer), JAX traces the increment function, compiles it using XLA, and executes it. The print statement inside increment runs during this initial trace, so it shows an abstract tracer object rather than a concrete number. On subsequent calls with compatible inputs (same type and shape), JAX reuses the cached compiled code and skips the Python execution (including the print). The state update still happens correctly within the compiled computation because the flow of state (in -> out) was part of the traced logic.
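One way to observe this caching behavior is to record each trace on the Python side. The sketch below does so with a module-level list named trace_log, which is a hypothetical helper for this demonstration. It deliberately relies on a Python side effect to expose when tracing happens, so it is a diagnostic trick and not a pattern to use in real code. It also starts from an explicit int32 array so that every call sees the same dtype.

```python
import jax
import jax.numpy as jnp

trace_log = []  # Python-side list, only appended to while JAX is tracing

@jax.jit
def increment(state):
    # Executes once per distinct (shape, dtype) signature, not once per call.
    trace_log.append((state.shape, str(state.dtype)))
    return state + 1

state = jnp.zeros((), dtype=jnp.int32)
for _ in range(3):
    state = increment(state)          # same signature: traced once, then cached
traces_after_scalar_calls = len(trace_log)

vector_state = increment(jnp.zeros((3,), dtype=jnp.int32))  # new shape: retrace
traces_after_vector_call = len(trace_log)

print(int(state))                  # the counter advanced on every call
print(traces_after_scalar_calls)   # number of traces for the three scalar calls
print(traces_after_vector_call)    # number of traces after the vector call
```

The state advances on all three scalar calls even though the Python body ran only once for that signature; passing a differently shaped state triggers a second trace.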
The following diagram illustrates this flow:
Figure: A pure function takes the current state and any other inputs, performs computations, and returns the updated state along with any other results.
This pattern extends to jax.grad as well. If the state update involves differentiable operations, JAX can differentiate the function with respect to its inputs (including the state, if needed) because the entire data flow is explicit.
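As a small sketch of differentiating through a state update, consider an exponential moving average. The function ema_update and the constant DECAY are assumptions chosen for illustration; the integer counter is not used here because jax.grad requires floating-point inputs.

```python
import jax

DECAY = 0.9  # illustrative smoothing factor

def ema_update(state, x):
    """Pure state update: returns the new moving average."""
    return DECAY * state + (1.0 - DECAY) * x

# Derivative of the new state with respect to the old state (argument 0).
d_new_d_state = jax.grad(ema_update, argnums=0)(1.0, 5.0)

# Derivative of the new state with respect to the new sample (argument 1).
d_new_d_x = jax.grad(ema_update, argnums=1)(1.0, 5.0)

print(d_new_d_state)  # approximately DECAY
print(d_new_d_x)      # approximately 1 - DECAY
```

Because the state enters and leaves the function as an ordinary argument and return value, it can be selected with argnums like any other input.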
While our counter example used a simple integer state, real-world applications often involve complex state structures, such as nested dictionaries containing model parameters and optimizer statistics (means, variances). JAX's support for PyTrees, discussed next, makes it convenient to manage these structured states with the explicit passing pattern.
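As a brief preview, the sketch below passes a nested dictionary as the state. The layout of the dictionary, the function names init_state and apply_update, and the constant LEARNING_RATE are all assumptions for this example; the gradients are hard-coded values, not the output of a real model.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value

def init_state():
    """State is a nested dict: parameters plus a step counter."""
    return {
        "params": {"w": jnp.ones((2,)), "b": jnp.zeros(())},
        "step": jnp.zeros((), dtype=jnp.int32),
    }

@jax.jit
def apply_update(state, grads):
    """Returns a new state dict; the input dict is left unchanged."""
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, state["params"], grads
    )
    return {"params": new_params, "step": state["step"] + 1}

state = init_state()
grads = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}  # made-up gradients
new_state = apply_update(state, grads)

print(new_state["params"]["w"])  # each weight moved against its gradient
print(int(new_state["step"]))    # step counter advanced by one
print(int(state["step"]))        # original state is untouched
```

The function signature is the same as for the counter: state in, new state out. Only the structure of the state has changed.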
Explicit state passing is the cornerstone of managing mutable processes within JAX's functional framework. By making the state flow explicit, we retain the ability to use JAX's powerful function transformations for acceleration and differentiation.
© 2025 ApX Machine Learning