Okay, let's put the theory of managing state into practice. The core idea, as we've discussed, is functional purity: functions shouldn't modify external state or have hidden side effects. Instead, they should take the current state as input and return the new state as output, along with any results. This "state-in, state-out" pattern is fundamental when working with JAX transformations like jit and grad.
These exercises will help you become comfortable with implementing this pattern in different scenarios. We'll start simple and build up to a task resembling a single step of a machine learning optimizer. Remember that JAX uses PyTrees (nested tuples, lists, and dictionaries) to handle structured state conveniently.
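Before the exercises, here is a minimal sketch of the contrast (the function names are just for illustration): an impure counter that mutates a hidden global versus a pure counter that receives and returns its state explicitly.
# Impure: relies on hidden, mutable module-level state. Under jax.jit the
# mutation happens only at trace time, so later calls reuse a stale value.
counter = 0

def impure_increment():
    global counter
    counter += 1
    return counter

# Pure: the state is an explicit input and an explicit output,
# so JAX transformations can safely trace and reuse the function.
def pure_increment(count):
    return count + 1

count = 0
count = pure_increment(count)  # state flows in and comes back out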
Let's revisit the stateful counter example. Your task is to implement a function update_counter that takes the current count (state) and an increment value. It should return the new count and the increment value itself as an auxiliary output. Then, apply jax.jit to this function and test it.
Instructions:
1. Import jax and jax.numpy.
2. Define the update_counter function. It should accept count and increment as arguments.
3. Inside the function, calculate new_count = count + 1.
4. Return the tuple (new_count, increment). The first element is the updated state, the second is the auxiliary output.
5. Create a jit-compiled version of this function using jax.jit.
6. Initialize the state (e.g., count = 0).
7. Use a for loop to call the jit-compiled function multiple times (e.g., 5 times). In each iteration:
   - Call it with the current count and a sample increment value (e.g., the loop index).
   - Print the returned count (updating the state for the next iteration) and the returned_increment.
import jax
import jax.numpy as jnp

# 1. Define the stateful function
def update_counter(count, increment):
    """Increments the count and returns the new count and the increment value."""
    new_count = count + 1
    # Return (new_state, auxiliary_output)
    return new_count, increment

# 2. JIT-compile the function
jitted_update_counter = jax.jit(update_counter)

# 3. Initialize state
current_count = 0
print(f"Initial count: {current_count}")

# 4. Run the loop, updating state each time
num_steps = 5
for i in range(num_steps):
    # Pass current state, get back new state and output
    current_count, returned_increment = jitted_update_counter(current_count, i)
    print(f"Step {i+1}: New Count = {current_count}, Returned Increment = {returned_increment}")
Expected Output:
Initial count: 0
Step 1: New Count = 1, Returned Increment = 0
Step 2: New Count = 2, Returned Increment = 1
Step 3: New Count = 3, Returned Increment = 2
Step 4: New Count = 4, Returned Increment = 3
Step 5: New Count = 5, Returned Increment = 4
This simple example demonstrates the core pattern: the loop manages the state (current_count) outside the compiled function, passing it in and receiving the updated version back in each step.
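As an aside (not required for the exercise), this state-in, state-out signature matches the carry/output convention of jax.lax.scan, so the Python loop above could itself be expressed as a single compiled scan. A minimal sketch, with update_counter redefined so the snippet runs on its own:
import jax
import jax.numpy as jnp

def update_counter(count, increment):
    return count + 1, increment

# scan threads the carry (count) through the increments and
# stacks the per-step auxiliary outputs into one array.
final_count, increments = jax.lax.scan(update_counter, 0, jnp.arange(5))
print(final_count)  # 5
print(increments)   # [0 1 2 3 4]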
Now, let's implement a function to compute a simple moving average (here, a running average over all values seen so far). The average updates as each new value arrives, so we need to keep track of the running sum and the number of values seen.
Instructions:
1. Define a function update_moving_average that takes state and new_value as input.
2. The state will be a tuple (current_sum, count). Initialize it appropriately (e.g., (0.0, 0)).
3. Inside the function, unpack the state tuple.
4. Calculate new_sum = current_sum + new_value.
5. Calculate new_count = count + 1.
6. Calculate current_average = new_sum / new_count. Handle the initial case where count might be 0 if necessary (though adding 1 first avoids division by zero here).
7. Package the new state: new_state = (new_sum, new_count).
8. Return (new_state, current_average).
9. Create a sample sequence of values and call update_moving_average for each number. Update the state variable in each iteration and print the calculated average.
10. Optionally, apply jax.jit to update_moving_average and observe if it works correctly.
import jax
import jax.numpy as jnp

# 1. Define the stateful function for moving average
def update_moving_average(state, new_value):
    """Updates the moving average state with a new value."""
    current_sum, count = state  # Unpack state
    new_sum = current_sum + new_value
    new_count = count + 1
    current_average = new_sum / new_count
    new_state = (new_sum, new_count)  # Package new state
    return new_state, current_average

# Optional: JIT-compile the function
# jitted_update_moving_average = jax.jit(update_moving_average)
# Use jitted_update_moving_average below if you uncomment this

# 2. Sample data and initial state
data_sequence = jnp.array([2.0, 4.0, 6.0, 8.0, 10.0])
initial_state = (0.0, 0)  # (sum, count)
print(f"Initial state (sum, count): {initial_state}")
print(f"Data sequence: {data_sequence}")

# 3. Iterate and update
current_state = initial_state
for value in data_sequence:
    # Pass current state and value, get back new state and average
    current_state, avg = update_moving_average(current_state, value)
    # Use this line instead if JITting:
    # current_state, avg = jitted_update_moving_average(current_state, value)
    print(f"After value {value:.1f}: New State = ({current_state[0]:.1f}, {current_state[1]}), Moving Average = {avg:.2f}")
Expected Output:
Initial state (sum, count): (0.0, 0)
Data sequence: [ 2. 4. 6. 8. 10.]
After value 2.0: New State = (2.0, 1), Moving Average = 2.00
After value 4.0: New State = (6.0, 2), Moving Average = 3.00
After value 6.0: New State = (12.0, 3), Moving Average = 4.00
After value 8.0: New State = (20.0, 4), Moving Average = 5.00
After value 10.0: New State = (30.0, 5), Moving Average = 6.00
Here, the state is a tuple, a simple PyTree. The function correctly updates and returns this structured state along with the calculated average. JIT compilation should work seamlessly because the function is pure and follows the state-passing pattern.
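Because any nesting of tuples, lists, and dictionaries is a valid PyTree, the state doesn't have to be a tuple. Below is a minimal sketch, not part of the exercise, of the same update with a dictionary state, which can read more clearly as the state grows:
import jax
import jax.numpy as jnp

def update_moving_average_dict(state, new_value):
    """Same update as above, but with the state stored in a dict PyTree."""
    new_sum = state["sum"] + new_value
    new_count = state["count"] + 1
    new_state = {"sum": new_sum, "count": new_count}
    return new_state, new_sum / new_count

jitted_update = jax.jit(update_moving_average_dict)

state = {"sum": 0.0, "count": 0}
for value in jnp.array([2.0, 4.0, 6.0]):
    state, avg = jitted_update(state, value)
    print(f"Moving Average = {avg:.2f}")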
This exercise simulates a single update step in an optimization algorithm like gradient descent. We'll manage the parameter being optimized as the state. Our goal is to minimize a simple function, for instance f(x) = x². Since the gradient of x² is 2x, each step computes x - learning_rate * 2x; with a learning rate of 0.1, every update multiplies x by 0.8, which explains the expected output below.
Instructions:
1. Define the loss function loss_fn(x) = x**2.
2. Get its gradient function with jax.grad: grad_fn = jax.grad(loss_fn).
3. Define the gradient_descent_step function. It should take params (the current value of x) and learning_rate as arguments.
4. Inside the function, compute the gradient at params using grad_fn.
5. Calculate new_params = params - learning_rate * gradient_value.
6. Return new_params (this is the updated state).
7. Initialize params (e.g., jnp.array(5.0)) and choose a learning_rate (e.g., 0.1).
8. Run a loop for several steps. In each step, call gradient_descent_step, passing the current params and learning_rate.
9. Replace the params variable with the returned new_params.
10. Optionally, apply jax.jit to gradient_descent_step. Does it work? Why?
import jax
import jax.numpy as jnp

# 1. Define the loss function
def loss_fn(x):
    return x**2

# 2. Get the gradient function
grad_fn = jax.grad(loss_fn)

# 3. Define the state update function (gradient descent step)
def gradient_descent_step(params, learning_rate):
    """Performs one step of gradient descent."""
    gradient_value = grad_fn(params)
    new_params = params - learning_rate * gradient_value
    # Return the new state (updated parameters)
    return new_params

# Optional: JIT-compile the step function
# jitted_gradient_descent_step = jax.jit(gradient_descent_step, static_argnums=(1,))
# Note: learning_rate is often treated as static if it doesn't change per step.

# 4. Initialization
current_params = jnp.array(5.0)
learning_rate = 0.1
num_steps = 5
print(f"Initial parameters: {current_params}")
print(f"Learning rate: {learning_rate}")

# 5. Perform update steps
for i in range(num_steps):
    # Pass current state (params), get back new state
    current_params = gradient_descent_step(current_params, learning_rate)
    # Use this line instead if JITting:
    # current_params = jitted_gradient_descent_step(current_params, learning_rate)
    print(f"Step {i+1}: Updated parameters = {current_params:.4f}")
Expected Output:
Initial parameters: 5.0
Learning rate: 0.1
Step 1: Updated parameters = 4.0000
Step 2: Updated parameters = 3.2000
Step 3: Updated parameters = 2.5600
Step 4: Updated parameters = 2.0480
Step 5: Updated parameters = 1.6384
This exercise demonstrates how the state (the model parameters params) is explicitly managed outside the update function. The gradient_descent_step function is pure; it calculates the new state based on its inputs and returns it. This pattern is essential when building optimizers or training loops in JAX. If you jit this function, it works as-is, precisely because it is pure. You can optionally mark learning_rate as a static argument (using static_argnums or static_argnames) when its value never changes: JAX then bakes it into the compiled code as a constant, which can help the compiler optimize, but passing a different value later will trigger recompilation.
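For reference, here is a small sketch of both jitting options, with the step function redefined so the snippet stands alone; prefer the traced argument if the learning rate may change between calls, since a static argument forces recompilation for each new value.
import jax
import jax.numpy as jnp

def loss_fn(x):
    return x**2

def gradient_descent_step(params, learning_rate):
    return params - learning_rate * jax.grad(loss_fn)(params)

# Option A: learning_rate is a regular (traced) argument;
# one compilation serves any learning-rate value.
jitted_step = jax.jit(gradient_descent_step)

# Option B: learning_rate is static; its value is baked into the compiled
# code, and passing a different value triggers a fresh compilation.
jitted_step_static = jax.jit(gradient_descent_step, static_argnames="learning_rate")

params = jnp.array(5.0)
params = jitted_step(params, 0.1)
params = jitted_step_static(params, learning_rate=0.1)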
These practical exercises reinforce the functional approach to state management required by JAX. By explicitly passing state in and out, your functions remain pure and compatible with JAX's powerful transformations like jit, grad, vmap, and pmap. This pattern scales effectively from simple counters to complex states involving nested PyTrees for neural network parameters and optimizer statistics.
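As a closing illustration, here is a hypothetical sketch of the same pattern applied to a nested parameter PyTree together with matching optimizer state (SGD with momentum); the layer names, shapes, and hyperparameters are invented for the example.
import jax
import jax.numpy as jnp

# Hypothetical nested parameter PyTree and a matching momentum PyTree
params = {"layer1": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}
momentum = jax.tree_util.tree_map(jnp.zeros_like, params)

def loss_fn(p, x):
    # Toy loss over a tiny linear layer
    return jnp.sum((x @ p["layer1"]["w"] + p["layer1"]["b"]) ** 2)

@jax.jit
def sgd_momentum_step(params, momentum, x, lr=0.01, beta=0.9):
    """One pure update step: state (params, momentum) in, new state out."""
    grads = jax.grad(loss_fn)(params, x)
    new_momentum = jax.tree_util.tree_map(lambda m, g: beta * m + g, momentum, grads)
    new_params = jax.tree_util.tree_map(lambda p, m: p - lr * m, params, new_momentum)
    return new_params, new_momentum

x = jnp.ones((3, 2))
for _ in range(3):
    params, momentum = sgd_momentum_step(params, momentum, x)
print(params["layer1"]["b"])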