Let's illustrate the explicit state-passing pattern with one of the simplest stateful operations imaginable: incrementing a counter. In standard Python, you might use a class with a method that modifies an internal attribute:
```python
# Standard Python (mutable state) - Not JAX-friendly
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

counter = Counter()
print(counter.increment())  # Output: 1
print(counter.increment())  # Output: 2
```
This approach relies on a side effect: `increment` modifies `self.count` in place. As discussed, JAX transformations like `jit` work best with pure functions. If we tried to `jit` the `increment` method directly, the update to `self.count` would run only while JAX traces the method. It would not be part of the compiled computation, so later calls to the compiled version would not keep counting.
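A quick way to see the problem is to `jit` the bound method directly. In this sketch (a fresh `Counter` instance, with `jitted_increment` as an illustrative name), `self.count += 1` is an ordinary Python side effect: it executes once while JAX traces `increment`, and the compiled function simply returns the value recorded at trace time.

```python
import jax

class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1  # Python side effect: executed only while JAX traces
        return self.count

counter = Counter()
jitted_increment = jax.jit(counter.increment)

first = jitted_increment()   # traces: self.count goes 0 -> 1, and 1 becomes a compiled constant
second = jitted_increment()  # cache hit: the compiled code runs, the Python body does not

print(int(first), int(second), counter.count)  # 1 1 1
```

The second call returns 1 again and `counter.count` stays at 1, because the mutation never became part of the compiled program.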
To make this compatible with JAX's functional approach, we redefine the operation as a pure function. This function takes the current state (the count) as an argument and returns the new state (the incremented count).
```python
import jax
import jax.numpy as jnp

# The state is just an integer (or a JAX scalar)
initial_state = 0

# Pure function: takes state, returns new state
def update_counter(current_state):
    """Increments the input state by 1."""
    print(f"Tracing update_counter with state: {current_state}")  # Added for tracing demo
    new_state = current_state + 1
    return new_state

# Explicitly pass state through the function
state_after_1_update = update_counter(initial_state)
state_after_2_updates = update_counter(state_after_1_update)
state_after_3_updates = update_counter(state_after_2_updates)

print(f"\nInitial State: {initial_state}")
print(f"State after 1 update: {state_after_1_update}")
print(f"State after 2 updates: {state_after_2_updates}")
print(f"State after 3 updates: {state_after_3_updates}")
```
Notice how the state is explicitly threaded through the function calls. The `update_counter` function itself doesn't change anything; it only computes a new value from its input.
Figure: Flow of state through the pure `update_counter` function. The output state from one call becomes the input for the next.
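In practice, this threading is usually written as a loop that rebinds a single variable. The sketch below redefines the update rule without the tracing `print`, starts from a JAX scalar, and keeps every intermediate state (in an illustrative `history` list) to show that earlier values are never modified by later updates.

```python
import jax.numpy as jnp

def update_counter(current_state):
    """Pure update rule: returns a new value and leaves its input untouched."""
    return current_state + 1

state = jnp.array(0)  # initial state as a JAX scalar
history = [state]
for _ in range(3):
    state = update_counter(state)  # rebind the name to the newly computed state
    history.append(state)

print([int(s) for s in history])  # [0, 1, 2, 3]
```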
Because `update_counter` is a pure function, we can safely apply JAX transformations like `jax.jit` to it.
```python
# JIT-compile the pure function
jitted_update_counter = jax.jit(update_counter)

# First call: JAX traces the function and compiles it
print("\nApplying JIT:")
jitted_state_1 = jitted_update_counter(initial_state)
print(f"State after 1 jitted update: {jitted_state_1}")

# Second call: Uses the cached, compiled version (no print from inside)
jitted_state_2 = jitted_update_counter(jitted_state_1)
print(f"State after 2 jitted updates: {jitted_state_2}")

# Third call: Also uses the compiled version
jitted_state_3 = jitted_update_counter(jitted_state_2)
print(f"State after 3 jitted updates: {jitted_state_3}")
```
You'll typically see the "Tracing update_counter..." message only once, on the first call. The `print` runs only while JAX traces the function (so it shows a tracer object rather than a concrete number), and JAX retraces only when the abstract signature of the input (its shape, dtype, or weak type) changes, not when its value changes. Subsequent calls reuse the optimized code that XLA compiled, which shows that `jit` handles our stateful computation correctly because the state is managed functionally.
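One way to check when retracing happens is to count traces with a Python side effect, since such code runs only during tracing. In this sketch, `trace_count` and `counted_update` are hypothetical helpers. Two Python ints share one trace because both are abstracted to a weakly typed int32 scalar, while a strongly typed `int32` array, a `float32` value, and a new shape each trigger a fresh trace (assuming the default 32-bit mode).

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces

@jax.jit
def counted_update(current_state):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not at run time
    return current_state + 1

counted_update(0)
counted_update(7)  # new value, same abstract signature: no retrace
after_python_ints = trace_count

counted_update(jnp.array(7, dtype=jnp.int32))      # int32 without weak type: retrace
counted_update(jnp.array(7.0, dtype=jnp.float32))  # new dtype: retrace
counted_update(jnp.zeros(3, dtype=jnp.int32))      # new shape: retrace
after_new_signatures = trace_count

print(after_python_ints, after_new_signatures)  # 1 4
```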
lax.scan
Manually chaining function calls works, but it is verbose and inefficient for many steps. A more idiomatic and performant way to express a sequence of stateful updates in JAX is `jax.lax.scan`, a functional loop construct designed for compilation.
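`jax.lax.scan` repeatedly applies a function (the "body function") to a carried state. At each step, the body function takes the current `carry` and an optional input slice `x` (taken along the leading axis of an input sequence `xs`) and returns the `new_carry` together with an optional output `y` for that step. After the last step, `scan` returns the final carry along with all the per-step `y` values stacked along a new leading axis.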
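The loop semantics of `jax.lax.scan` can be written out in plain Python. The sketch below follows the reference semantics described in the `jax.lax.scan` documentation, simplified to a single input array; `python_scan` and `running_sum` are illustrative names. Unlike the counter, a running sum both consumes per-step inputs `xs` and produces per-step outputs, so it exercises every part of the interface.

```python
import jax
import jax.numpy as jnp

def python_scan(f, init, xs):
    """Plain-Python reference loop with lax.scan's carry/output semantics
    (simplified: xs is one array scanned along its leading axis)."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)

def running_sum(total, x):
    new_total = total + x
    return new_total, new_total  # carry the total forward and also emit it as y

xs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
init = jnp.array(0.0, dtype=jnp.float32)

ref_carry, ref_ys = python_scan(running_sum, init, xs)
scan_carry, scan_ys = jax.lax.scan(running_sum, init, xs)

print(float(scan_carry), scan_ys.tolist())  # 10.0 [1.0, 3.0, 6.0, 10.0]
```

Both versions compute the same carry and outputs; the difference is that `lax.scan` turns the loop into a single traced operation that XLA can compile.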
```python
import functools

# Define the body function for scan
# Takes (carry_state, optional_input_slice)
# Returns (new_carry_state, optional_per_step_output)
def scan_body(carry_state, _):  # We don't have per-step inputs here, so use _
    """The function to apply at each step of the scan."""
    next_state = update_counter(carry_state)
    # We don't need a per-step output, just the final state
    per_step_output = None
    return next_state, per_step_output

num_steps = 10
initial_scan_state = 0

print(f"\nUsing lax.scan for {num_steps} steps:")
# Run the scan
# Signature: scan(f, init, xs=None, length=None, ...)
# Here, xs is None, so we specify the number of steps via length.
# scan traces scan_body instead of calling it once per step from Python,
# so the print inside update_counter appears at trace time (typically once).
final_state, accumulated_outputs = jax.lax.scan(
    scan_body,
    initial_scan_state,
    xs=None,  # No input sequence needed for a simple counter
    length=num_steps,
)
print(f"Initial state for scan: {initial_scan_state}")
print(f"Final state after scan: {final_state}")
# accumulated_outputs will be None because we returned None per step

# We can also JIT the entire scan operation for maximum efficiency.
# scan needs `length` as a concrete Python int, so `steps` must be static;
# otherwise jit would pass it in as a tracer and scan would raise an error.
@functools.partial(jax.jit, static_argnames="steps")
def run_scan_jitted(init_state, steps):
    final_st, _ = jax.lax.scan(scan_body, init_state, xs=None, length=steps)
    return final_st

print("\nUsing JITted lax.scan:")
final_state_jitted = run_scan_jitted(initial_scan_state, num_steps)
# Note: The "Tracing..." message from update_counter appears while jax.jit
# traces run_scan_jitted, not when the compiled code executes.
print(f"Final state after jitted scan: {final_state_jitted}")
```
Combining `lax.scan` with `jit` lets XLA compile the entire loop as a single computation. This is typically much faster than a Python loop that calls a jitted function once per step, because it avoids dispatching from Python on every iteration.
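The sketch below runs both styles side by side on the counter so their results can be compared. It uses a print-free copy of the update rule and the hypothetical names `step`, `count_with_python_loop`, and `count_with_scan`; the scan version also returns every intermediate state as its per-step output. No timings are shown, since the speedup depends on hardware and the number of steps.

```python
import functools
import jax
import jax.numpy as jnp

@jax.jit
def step(state):
    return state + 1  # print-free version of update_counter

def count_with_python_loop(init_state, steps):
    # Python drives the loop: one dispatch of the compiled `step` per iteration
    state = init_state
    for _ in range(steps):
        state = step(state)
    return state

@functools.partial(jax.jit, static_argnames="steps")
def count_with_scan(init_state, steps):
    def body(carry, _):
        new_carry = carry + 1
        return new_carry, new_carry  # also emit each intermediate state as y
    return jax.lax.scan(body, init_state, xs=None, length=steps)

init = jnp.array(0, dtype=jnp.int32)
loop_final = count_with_python_loop(init, 10)
scan_final, history = count_with_scan(init, 10)

print(int(loop_final), int(scan_final))  # 10 10
print(history.tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
```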
While this counter uses a simple integer state, the exact same pattern applies to more complex states, such as the nested dictionaries or lists of JAX arrays used for model parameters or optimizer statistics. As long as your update function takes the entire state structure (as a PyTree) as input and returns the entire updated state with the same tree structure, shapes, and dtypes (which `lax.scan` requires of its carry), JAX transformations will handle it correctly. We'll see this in action when we look at managing optimizer state next.
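As a sketch of the PyTree case, the hypothetical state below is a dictionary holding a step count and a running total. The illustrative `update_stats` function returns a new dictionary with the same keys, shapes, and dtypes, so it can serve directly as a `lax.scan` body inside a jitted function.

```python
import jax
import jax.numpy as jnp

def update_stats(state, x):
    """Pure update of a dict-shaped state (hypothetical structure for illustration)."""
    new_state = {
        "count": state["count"] + 1,
        "total": state["total"] + x,
    }
    return new_state, None  # no per-step output needed

@jax.jit
def summarize(xs):
    init_state = {
        "count": jnp.array(0, dtype=jnp.int32),
        "total": jnp.array(0.0, dtype=jnp.float32),
    }
    final_state, _ = jax.lax.scan(update_stats, init_state, xs)
    return final_state

final = summarize(jnp.array([0.5, 1.5, 2.0], dtype=jnp.float32))
print(int(final["count"]), float(final["total"]))  # 3 4.0
```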
This stateful counter example demonstrates the core functional pattern for managing state in JAX: treat state as an immutable value passed explicitly into and out of pure functions. This approach ensures compatibility with JAX transformations like `jit` and enables efficient execution patterns like `lax.scan`.
© 2025 ApX Machine Learning