You've learned how to manage state in JAX by adopting functional patterns, primarily by passing state explicitly into functions and returning the updated state. This approach maintains functional purity, which is essential for compatibility with JAX's transformations. Now, let's examine how these state management techniques work when combined with jax.jit, jax.grad, jax.vmap, and jax.pmap. The good news is that because our stateful functions remain pure, they compose cleanly with these transformations.
Using jax.jit with functions that manage state is straightforward. Since the function takes the current state as an argument and returns the new state, jit can trace it just like any other pure function. The state, often represented as a PyTree, is treated as a regular input and output by the tracer.
Let's revisit the simple stateful counter example:
import jax
import jax.numpy as jnp
def stateful_counter(count_state, increment):
    """Increments a counter state."""
    new_count = count_state['count'] + increment
    return {'count': new_count}  # Return the new state
# Initial state
initial_state = {'count': 0}
# Apply the function
state1 = stateful_counter(initial_state, jnp.array(1))
state2 = stateful_counter(state1, jnp.array(5))
print(f"Initial state: {initial_state}")
print(f"State after first increment: {state1}")
print(f"State after second increment: {state2}")
Now, let's compile stateful_counter using jax.jit:
# JIT-compile the function
jit_counter = jax.jit(stateful_counter)
# Run the compiled version
jitted_state1 = jit_counter(initial_state, jnp.array(1))
jitted_state2 = jit_counter(jitted_state1, jnp.array(5)) # Re-uses compiled code if shapes/types match
print(f"\nJITted state after first increment: {jitted_state1}")
print(f"JITted state after second increment: {jitted_state2}")
# Verify the structure and values are the same
assert jax.tree_util.tree_all(jax.tree_util.tree_map(lambda x, y: jnp.all(x == y), state2, jitted_state2))
As you can see, jit handles the state dictionary (a PyTree) without issues. JAX traces the function with the initial state structure and the types/shapes of the arguments. Subsequent calls with matching structures and types reuse the compiled code, providing significant speedups for complex stateful computations like neural network training steps.
Important Note: Remember that jit traces the function based on the structure of the state PyTree and the types/shapes of its leaf nodes (the arrays). If the structure of your state changes between calls (e.g., adding new keys to a dictionary), jit will need to recompile the function, potentially impacting performance. It's best practice to maintain a consistent state structure.
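To see this in practice, here is a minimal sketch (traced_counter below is illustrative, not a function from earlier in the chapter) that uses a Python print to reveal tracing: the print only runs while JAX traces the function, so it fires once for a consistent state structure and again when the structure changes.
# A minimal sketch: the Python print runs only during tracing, so seeing it
# again signals that jit had to recompile.
@jax.jit
def traced_counter(count_state, increment):
    print("Tracing with structure:", jax.tree_util.tree_structure(count_state))
    return jax.tree_util.tree_map(lambda c: c + increment, count_state)

traced_counter({'count': jnp.array(0)}, jnp.array(1))    # traces and compiles
traced_counter({'count': jnp.array(7)}, jnp.array(2))    # same structure: reuses compiled code
traced_counter({'count': jnp.array(0), 'extra': jnp.array(0)}, jnp.array(1))  # new key: retraces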
Automatic differentiation with jax.grad also integrates well with explicit state passing. Often, the state contains the parameters we want to differentiate with respect to (e.g., model weights). Functions typically return both a value to be differentiated (like a loss) and the updated state.
Consider a simple function that calculates a squared error loss and updates a parameter state:
import jax
import jax.numpy as jnp
def predict_and_update(params, x):
    """A simple linear prediction function. State = params."""
    # Prediction uses the parameters from the state
    pred = params['w'] * x + params['b']
    # Return prediction (value) and unchanged state
    return pred, params  # State is not modified here

def loss_fn(params, x, y_target):
    """Calculates loss and doesn't change state."""
    pred, _ = predict_and_update(params, x)  # Use the predict function
    loss = jnp.mean((pred - y_target)**2)
    # Return only the loss value; state is handled separately
    return loss
# Example parameters (state) and data
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([3.5, 5.5, 7.5])  # Ideal fit: w=2.0, b=1.5
# Calculate the loss
current_loss = loss_fn(params_state, x_data, y_target_data)
print(f"Current loss: {current_loss}")
# Get the gradient function for the loss w.r.t. 'params' (arg 0)
grad_loss_fn = jax.grad(loss_fn, argnums=0) # Differentiate w.r.t. params
# Calculate gradients
grads = grad_loss_fn(params_state, x_data, y_target_data)
print(f"Gradients: {grads}")
Here, loss_fn takes params (our state) as input. We use jax.grad with argnums=0 to get gradients with respect to params. JAX correctly traces through the predict_and_update function used inside loss_fn and computes the gradients for w and b.
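As a quick sanity check (not part of the original example), the analytic gradients of this mean squared error are mean(2 * (pred - y) * x) for w and mean(2 * (pred - y)) for b, which should match the values jax.grad returns:
# Manual gradients of mean((w*x + b - y)**2) for comparison with jax.grad
pred = params_state['w'] * x_data + params_state['b']
manual_dw = jnp.mean(2 * (pred - y_target_data) * x_data)
manual_db = jnp.mean(2 * (pred - y_target_data))
print(f"Manual gradients: w={manual_dw}, b={manual_db}")  # should match grads['w'] and grads['b']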
Often, you'll want both the loss value and the gradients. jax.value_and_grad is perfect for this:
# Get a function that returns both loss and gradients
value_and_grad_fn = jax.value_and_grad(loss_fn, argnums=0)
# Calculate loss and gradients simultaneously
loss_val, grads_val = value_and_grad_fn(params_state, x_data, y_target_data)
print(f"\nUsing value_and_grad:")
print(f"Loss: {loss_val}")
print(f"Gradients: {grads_val}")
Now, let's combine this with a state update step, simulating a single step of gradient descent:
def training_step(params, x, y_target, learning_rate):
    """Performs one step of gradient descent."""
    loss, grads = jax.value_and_grad(loss_fn, argnums=0)(params, x, y_target)
    # Update params using the gradients (explicit state update)
    # jax.tree_util.tree_map applies a function leaf-wise to PyTrees
    updated_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    # Return the loss and the new state (updated parameters)
    return loss, updated_params
# Perform one training step
learning_rate = 0.1
loss_step1, params_step1 = training_step(params_state, x_data, y_target_data, learning_rate)
print(f"\nAfter one training step:")
print(f"Loss: {loss_step1}")
print(f"Updated Params: {params_step1}")
# Perform another step
loss_step2, params_step2 = training_step(params_step1, x_data, y_target_data, learning_rate)
print(f"\nAfter second training step:")
print(f"Loss: {loss_step2}")
print(f"Updated Params: {params_step2}")
This training_step function takes the parameter state, computes gradients, and returns the updated parameter state. It's a pure function, making it suitable for further transformations like jit.
vmap
jax.vmap allows you to automatically vectorize functions, including those that manage state. This is extremely useful for processing batches of data. You tell vmap how each argument, including the state, should be mapped over the batch dimension using the in_axes argument.
Let's modify our counter to operate on batches. Suppose we have a batch of increments and want independent counters (though sharing the same logic):
# Define the initial state for a batch of counters
# Assume 3 counters in our batch
batch_size = 3
# State needs to be replicated or batched accordingly
batched_initial_state = {'count': jnp.array([0, 0, 0])} # Batched state
# Batch of increments
batched_increments = jnp.array([1, 5, 10])
# Vectorize the counter function over the state ('count') and the increment
# axis 0 for state['count'], axis 0 for increment
vmap_counter = jax.vmap(stateful_counter, in_axes=({'count': 0}, 0))
# Apply the vectorized function
batched_state1 = vmap_counter(batched_initial_state, batched_increments)
print(f"\nVectorized Counter:")
print(f"Batched initial state: {batched_initial_state}")
print(f"Batched increments: {batched_increments}")
print(f"Batched state after increment: {batched_state1}")
# Apply again with different increments
batched_increments2 = jnp.array([2, 3, 4])
batched_state2 = vmap_counter(batched_state1, batched_increments2)
print(f"Batched state after second increment: {batched_state2}")
Here, in_axes=({'count': 0}, 0) tells vmap:

- For the first argument (the state dictionary count_state), look inside the dictionary and map the value under the key 'count' along axis 0.
- For the second argument (increment), map along axis 0.

The out_axes argument (which defaults to 0 for all outputs) specifies how the outputs are structured. In this case, the returned state dictionary {'count': ...} has its 'count' value stacked along axis 0.
In machine learning, it's common to share model parameters across a batch. vmap handles this easily: set the corresponding in_axes entry to None.
# Example: Applying predict_and_update over a batch of x, using the *same* params
# Params (state) are shared: in_axes=None
# x_data is batched: in_axes=0
vmap_predict = jax.vmap(predict_and_update, in_axes=(None, 0))
# Our original single parameter state
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
# Batch of x data
x_batch = jnp.array([1.0, 2.0, 3.0, 4.0])
# Run the vectorized prediction
# Output state will be replicated, prediction will be batched
batched_preds, batched_params_out = vmap_predict(params_state, x_batch)
print(f"\nVectorized Prediction (Shared Params):")
print(f"Input Params (State): {params_state}")
print(f"Input x batch: {x_batch}")
print(f"Batched Predictions: {batched_preds}")
# Note: The output state is just the input state replicated by vmap
# print(f"Output Params (State): {batched_params_out}") # Would show replicated params
pmap
Similar principles apply when using jax.pmap for parallelization across multiple devices (GPUs/TPUs), although the details involve device placement and collective operations. Like vmap, pmap requires specifying how inputs (including state) are mapped to devices using in_axes.
- Replicated state (e.g., model parameters) is typically passed with in_axes=None. Each device holds a full copy.
- Sharded data (e.g., input batches) is split across devices along a leading axis (in_axes=0).

When state is updated in parallel (e.g., calculating gradients on different data shards), you often need collective operations (lax.psum, lax.pmean, etc.) within the pmapped function to aggregate results (like averaging gradients) before updating a replicated state. Handling state with pmap requires careful consideration of data distribution and synchronization, building on the concepts discussed in the pmap chapter.
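As an illustrative sketch only (the axis name, reshaping, and function names here are assumptions, not code from earlier in this chapter), a data-parallel training step might replicate the parameters, shard the data, and average gradients with lax.pmean before the update:
from jax import lax

def parallel_training_step(params, x_shard, y_shard, learning_rate):
    """One data-parallel step: per-device gradients, averaged across devices."""
    loss, grads = jax.value_and_grad(loss_fn, argnums=0)(params, x_shard, y_shard)
    # Average the gradients (and loss) across devices along the named axis
    grads = lax.pmean(grads, axis_name='devices')
    loss = lax.pmean(loss, axis_name='devices')
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return loss, new_params

# Replicate params and the learning rate; shard x and y along their leading axis
p_train_step = jax.pmap(parallel_training_step, axis_name='devices',
                        in_axes=(None, 0, 0, None))
# Usage sketch: reshape the batch so its leading axis equals jax.local_device_count(), e.g.
#   x_sharded = x_all.reshape(jax.local_device_count(), -1)
# then call p_train_step(params_state, x_sharded, y_sharded, 0.1); the outputs carry a
# leading device axis holding identical copies of the loss and updated parameters.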
The true utility emerges when you compose these transformations. For instance, a typical machine learning training loop involves calculating gradients for a batch of data (vmap + grad) and compiling the entire step for performance (jit).
Let's JIT-compile our training_step function:
# JIT-compile the training step function
jit_training_step = jax.jit(training_step)
# Reset state
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
learning_rate = 0.1
x_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([3.5, 5.5, 7.5])
print(f"\nJITted Training Step:")
print(f"Initial Params: {params_state}")
# Run the compiled training step
loss_jitted1, params_jitted1 = jit_training_step(params_state, x_data, y_target_data, learning_rate)
print(f"Step 1 Loss: {loss_jitted1}, Params: {params_jitted1}")
# Run again (should be faster due to compilation cache)
loss_jitted2, params_jitted2 = jit_training_step(params_jitted1, x_data, y_target_data, learning_rate)
print(f"Step 2 Loss: {loss_jitted2}, Params: {params_jitted2}")
This jit_training_step now efficiently performs gradient calculation and parameter updates. We could further wrap it with vmap if we needed to process multiple independent training batches simultaneously (less common), or integrate it within pmap for distributed training.
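For example, here is a sketch of that composition (per_example_loss, batched_loss, and composed_step are illustrative names): vmap computes per-example losses with shared parameters, grad differentiates their mean, and jit compiles the whole step.
def per_example_loss(params, x, y):
    pred, _ = predict_and_update(params, x)
    return (pred - y) ** 2

def batched_loss(params, xs, ys):
    # Vectorize over the batch while sharing the parameters
    losses = jax.vmap(per_example_loss, in_axes=(None, 0, 0))(params, xs, ys)
    return jnp.mean(losses)

@jax.jit
def composed_step(params, xs, ys, learning_rate):
    loss, grads = jax.value_and_grad(batched_loss)(params, xs, ys)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return loss, new_params

loss_c, params_c = composed_step(params_state, x_data, y_target_data, 0.1)
print(f"\nComposed (vmap + grad + jit) step: loss={loss_c}, params={params_c}")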
The explicit state passing pattern, combined with PyTrees, provides a robust way to handle state that works harmoniously with JAX's core transformations, enabling complex, high-performance computations like training large machine learning models.