Let's apply the explicit state passing pattern to a common scenario in machine learning: managing the state of an optimization algorithm. Optimizers like Stochastic Gradient Descent (SGD) or Adam update model parameters iteratively based on computed gradients. Besides the parameters themselves, many optimizers also maintain their own internal state, such as momentum vectors or adaptive learning rates.
Consider the basic Gradient Descent update rule. To find the minimum of a function f(params), we iteratively update the parameters params by moving in the direction opposite to the gradient ∇f(params):
params_new = params_old − learning_rate × ∇f(params_old)

In a functional setting, we can't modify params in place. Instead, we need a function that takes the current parameters and gradients and returns the new parameters.
Let's try to minimize a simple quadratic function f(x) = (x − 3)². The minimum is clearly at x = 3. The gradient is ∇f(x) = 2(x − 3).
```python
import jax
import jax.numpy as jnp

# Define the function to minimize
def objective_function(x):
    return (x - 3.0)**2

# Calculate the gradient using jax.grad
grad_fn = jax.grad(objective_function)

# Define the optimizer update step (simple SGD)
# This function takes the current state (just 'params' here)
# and the gradients, and returns the new state.
def sgd_update(params, gradients, learning_rate):
    """Performs a single SGD update step."""
    new_params = params - learning_rate * gradients
    # Return the updated state
    return new_params

# --- Optimization Loop ---
# Initial parameter value (our initial state)
current_params = jnp.array(0.0)
learning_rate = 0.1
num_steps = 20

print(f"Initial params: {current_params:.4f}")

# Run the optimization loop
for step in range(num_steps):
    # 1. Calculate gradients for the current parameters
    gradients = grad_fn(current_params)

    # 2. Compute the new parameters using the update function
    #    Pass the current state ('current_params') and gradients
    #    Receive the new state ('next_params')
    next_params = sgd_update(current_params, gradients, learning_rate)

    # 3. Update the state for the next iteration
    current_params = next_params

    if (step + 1) % 5 == 0:
        print(f"Step {step+1:3d}, Params: {current_params:.4f}, Gradient: {gradients:.4f}")

print(f"\nFinal optimized params: {current_params:.4f}")
```
In this loop, current_params holds the state. The sgd_update function is pure; it takes the state and gradients and returns a new state (next_params). We then explicitly reassign current_params = next_params to carry the state forward. This pattern works perfectly with JAX transformations.
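Because sgd_update is pure and the state is passed explicitly, the whole loop can also be expressed with jax.lax.scan, which compiles the entire optimization as a single program. A minimal sketch, with the learning rate hard-coded for brevity:

```python
import jax
import jax.numpy as jnp

def objective_function(x):
    return (x - 3.0)**2

grad_fn = jax.grad(objective_function)

def sgd_step(params, _):
    # One scan step: compute the gradient and return the new state.
    gradients = grad_fn(params)
    new_params = params - 0.1 * gradients
    return new_params, new_params  # (carry, per-step output)

# Run 20 steps; 'history' collects params after every step.
final_params, history = jax.lax.scan(sgd_step, jnp.array(0.0), xs=None, length=20)
```

The carry argument of scan plays exactly the role of current_params in the explicit loop; scan threads it from one step to the next for us.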
Many optimizers require additional state. Let's implement SGD with momentum. The update involves a velocity term v:

v_new = momentum × v_old + learning_rate × ∇f(params_old)
params_new = params_old − v_new

Notice that we now need to track both params and velocity (v). This combined information forms the optimizer's state. We can use a PyTree, such as a dictionary or a NamedTuple, to hold this structured state.
```python
import jax
import jax.numpy as jnp
from typing import NamedTuple  # Or use a dictionary

# Define the function to minimize (same as before)
def objective_function(x):
    return (x - 3.0)**2

# Calculate the gradient using jax.grad
grad_fn = jax.grad(objective_function)

# Define a structure for the optimizer state using NamedTuple (or a dict)
class OptimizerState(NamedTuple):
    params: jax.Array
    velocity: jax.Array

# Define the optimizer update step (SGD with Momentum)
# This function takes the combined state (params and velocity)
# and the gradients, and returns the new combined state.
def momentum_update(state: OptimizerState, gradients, learning_rate, momentum):
    """Performs a single SGD with momentum update step."""
    new_velocity = momentum * state.velocity + learning_rate * gradients
    new_params = state.params - new_velocity
    # Return the updated state as a new OptimizerState object
    return OptimizerState(params=new_params, velocity=new_velocity)

# --- Optimization Loop with Momentum ---
# Initial parameter value and velocity
initial_params = jnp.array(0.0)
initial_velocity = jnp.array(0.0)

# Initial state is now a structure containing params and velocity
current_state = OptimizerState(params=initial_params, velocity=initial_velocity)
learning_rate = 0.1
momentum_coeff = 0.9  # Common momentum value
num_steps = 20

print(f"Initial state: params={current_state.params:.4f}, velocity={current_state.velocity:.4f}")

# Run the optimization loop
for step in range(num_steps):
    # 1. Calculate gradients for the current parameters within the state
    gradients = grad_fn(current_state.params)

    # 2. Compute the new state using the momentum update function
    #    Pass the current combined state and gradients
    #    Receive the new combined state
    next_state = momentum_update(current_state, gradients, learning_rate, momentum_coeff)

    # 3. Update the state for the next iteration
    current_state = next_state

    if (step + 1) % 5 == 0:
        print(f"Step {step+1:3d}, Params: {current_state.params:.4f}, Velocity: {current_state.velocity:.4f}, Gradient: {gradients:.4f}")

print(f"\nFinal optimized params: {current_state.params:.4f}")
```
Here, the state (current_state) is a NamedTuple (it could equally be a dictionary such as {'params': ..., 'velocity': ...}). The momentum_update function takes this entire state object as input and returns a new, updated state object. The loop structure remains the same: compute gradients, call the update function with the current state, and use the returned new state for the next step.
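To make the PyTree point concrete, here is a sketch of the same momentum update with the state held in a plain dictionary instead of a NamedTuple; JAX treats both as PyTrees, so nothing else changes. The loop runs longer here because momentum oscillates around the minimum before settling:

```python
import jax
import jax.numpy as jnp

# Gradient of the same objective as before
grad_fn = jax.grad(lambda x: (x - 3.0)**2)

def momentum_update_dict(state, gradients, learning_rate, momentum):
    """Same momentum update, with the state held in a dictionary PyTree."""
    new_velocity = momentum * state["velocity"] + learning_rate * gradients
    new_params = state["params"] - new_velocity
    return {"params": new_params, "velocity": new_velocity}

state = {"params": jnp.array(0.0), "velocity": jnp.array(0.0)}
for _ in range(100):  # extra steps let the momentum oscillations die out
    grads = grad_fn(state["params"])
    state = momentum_update_dict(state, grads, 0.1, 0.9)
```

Which container to use is a matter of taste: dictionaries are quick to write, while a NamedTuple gives attribute access and a fixed set of fields.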
A significant advantage of this explicit state management pattern is its natural compatibility with JAX transformations like jax.jit. We can easily compile our update function for better performance:
```python
# Compile the momentum update function
# learning_rate and momentum are static (compile-time constants)
jitted_momentum_update = jax.jit(momentum_update, static_argnums=(2, 3))

# --- Re-run the Optimization Loop using the JITted function ---
# Reset state
current_state = OptimizerState(params=initial_params, velocity=initial_velocity)

print("\n--- Running with JITted Update ---")
print(f"Initial state: params={current_state.params:.4f}, velocity={current_state.velocity:.4f}")

for step in range(num_steps):
    gradients = grad_fn(current_state.params)  # Can also jit grad_fn if desired

    # Use the compiled update function
    next_state = jitted_momentum_update(current_state, gradients, learning_rate, momentum_coeff)
    current_state = next_state

    if (step + 1) % 5 == 0:
        print(f"Step {step+1:3d}, Params: {current_state.params:.4f}, Velocity: {current_state.velocity:.4f}, Gradient: {gradients:.4f}")

print(f"\nFinal optimized params (JITted): {current_state.params:.4f}")
```
Because momentum_update is a pure function (its output depends only on its inputs, with no side effects) and follows the explicit state passing pattern, jax.jit can trace and compile it effectively. We mark learning_rate and momentum as static arguments because they are plain Python floats that stay fixed throughout the loop; JAX bakes their values into the compiled code, and recompilation would only be triggered if those values changed.
This example demonstrates how to handle the evolving state of an optimizer within JAX's functional paradigm. By explicitly passing state in and out of pure update functions, often using PyTrees for structure, we create code that is clear, manageable, and readily compatible with JAX's powerful transformations like jit and grad. This pattern is fundamental when building more complex models and training loops in JAX.
© 2025 ApX Machine Learning