As discussed in the chapter introduction, training large machine learning models effectively requires careful management of not just the model parameters themselves, but also associated state like optimizer statistics and pseudo-random number generator (PRNG) keys. JAX's functional programming paradigm, where functions ideally have no side effects, means that state must be explicitly managed: passed into functions and returned as output. While this might seem verbose initially, it offers significant clarity and simplifies debugging, especially in complex distributed settings.
In JAX, model parameters, optimizer state (e.g., momentum buffers, learning rate schedules), batch normalization statistics, and PRNG keys are typically represented as PyTrees. A PyTree is simply any Python object that JAX can treat as a nested structure of containers (like lists, tuples, dictionaries) and leaves (like JAX arrays or standard Python types).
import jax
import jax.numpy as jnp
import optax  # Common JAX optimizer library

# Example structure for parameters (could be much deeper)
params = {
    'encoder': {
        'layer_1': {'w': jnp.ones((128, 256)), 'b': jnp.zeros(256)},
        'layer_norm': {'scale': jnp.ones(256), 'bias': jnp.zeros(256)}
    },
    'decoder': {
        'output': {'w': jnp.ones((256, 10)), 'b': jnp.zeros(10)}
    }
}

# Example optimizer state (structure often mirrors params)
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# A single PRNG key
key = jax.random.PRNGKey(0)

# You might bundle these together for convenience
training_state = {
    'params': params,
    'opt_state': opt_state,
    'rng_key': key,
    'step': 0
}

# Verify it's a PyTree
leaves, treedef = jax.tree_util.tree_flatten(training_state)
print(f"Number of leaf nodes (arrays, scalars): {len(leaves)}")
# Output: the exact count depends on the structure and optimizer
# (here: the 6 parameter arrays, Adam's step counter and moment buffers, the PRNG key, and the step)
Using PyTrees is fundamental because JAX transformations (jit, grad, vmap, pmap) are designed to operate seamlessly over these structures. When you apply jax.grad to a function that takes and returns PyTrees, JAX computes gradients with respect to every numerical leaf (array) in the specified input PyTree(s). Similarly, pmap automatically replicates or distributes PyTree structures across devices based on its arguments. This makes managing potentially complex nested state structures much more systematic.
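To make this concrete, here is a small sketch that reuses the params dictionary above. The toy_loss function is a made-up example chosen only for illustration; the point is that jax.grad returns a gradient PyTree with the same nested structure as params, and jax.tree_util.tree_map applies a function to every leaf.

# A minimal sketch: toy_loss is arbitrary, chosen only to show that gradients
# come back as a PyTree matching the structure of 'params'
def toy_loss(params, x):
    h = x @ params['encoder']['layer_1']['w'] + params['encoder']['layer_1']['b']
    logits = h @ params['decoder']['output']['w'] + params['decoder']['output']['b']
    return jnp.mean(logits ** 2)

x = jnp.ones((4, 128))
grads = jax.grad(toy_loss)(params, x)  # same nested dict structure as params

# tree_map applies a function to every leaf, e.g. to inspect gradient shapes
shapes = jax.tree_util.tree_map(lambda a: a.shape, grads)
print(shapes['decoder']['output']['w'])  # (256, 10)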
While you can manage state manually using dictionaries or custom classes, libraries like Flax and Haiku provide higher-level abstractions specifically designed for neural networks, streamlining state management significantly.
Flax often encourages grouping related state components into dedicated objects, commonly using patterns like TrainState. This class typically bundles the model parameters (params), the optimizer state (opt_state), the current training step (step), and a reference to the model's apply function (apply_fn).
# Conceptual Flax example
from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state

class SimpleModel(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.features)(x)
        return x

class TrainState(train_state.TrainState):
    # Optionally add batch stats or other state here
    batch_stats: Any = None  # Example for BatchNorm

# Initialization
key = jax.random.PRNGKey(0)
model = SimpleModel(features=10)
dummy_input = jnp.ones([1, 128])
params = model.init(key, dummy_input)['params']
optimizer = optax.adam(1e-3)

# Create the state object
state = TrainState.create(
    apply_fn=model.apply,  # Function to run the model
    params=params,
    tx=optimizer           # Optimizer (gradient transformation)
)

# Inside a training step, you'd update this state immutably:
# new_state = state.apply_gradients(grads=grads)
The key idea remains the same: state is held explicitly in an object (a PyTree), and updates produce new state objects rather than modifying them in place. This aligns perfectly with JAX's functional approach.
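To illustrate the pattern, here is a minimal training-step sketch built on the state object above. The mean-squared-error loss and the batch format (a dict with 'x' and 'y') are assumptions made purely for demonstration.

# A minimal sketch of a Flax training step; the loss and batch format are illustrative
@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({'params': params}, batch['x'])
        return jnp.mean((preds - batch['y']) ** 2)

    grads = jax.grad(loss_fn)(state.params)
    # apply_gradients returns a *new* TrainState (params, opt_state, step all updated);
    # nothing is mutated in place
    return state.apply_gradients(grads=grads)

# Usage:
# batch = {'x': jnp.ones([8, 128]), 'y': jnp.ones([8, 10])}
# state = train_step(state, batch)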
Haiku takes a slightly different approach using hk.transform or hk.transform_with_state. It separates the pure function logic from the parameters (params) and the mutable state (state, e.g., batch normalization statistics).
import haiku as hk
import jax
import jax.numpy as jnp
import optax

def forward_fn(x, is_training):
    # BatchNorm needs an is_training flag, so the layers are called explicitly
    # rather than through hk.Sequential (which only forwards extra arguments
    # to its first layer)
    x = hk.Linear(512)(x)
    x = jax.nn.relu(x)
    x = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99)(x, is_training=is_training)
    x = hk.Linear(10)(x)
    return x

# Transform the function to handle mutable state
forward = hk.transform_with_state(forward_fn)

key = hk.PRNGSequence(0)
dummy_input = jnp.ones([1, 128])

# Initialize parameters and mutable state
params, state = forward.init(next(key), dummy_input, is_training=True)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

# Apply the function (with is_training=True, batch norm statistics are updated)
# logits, new_state = forward.apply(params, state, next(key), batch_data, is_training=True)
# grads, new_state = grad_fn(params, state, ...)  # grad_fn must thread state through
# updates, new_opt_state = optimizer.update(grads, opt_state, params)
# new_params = optax.apply_updates(params, updates)
Haiku requires you to explicitly manage the params and state PyTrees returned by init and apply. Again, the functional principle holds: state is passed in, and updated state is returned.
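Here is a minimal sketch of how the gradient computation threads that mutable state, building on the Haiku example above. The mean-squared-error loss and batch shapes are assumptions for illustration; the essential pattern is returning the new state as an auxiliary output via has_aux=True.

# A sketch of a loss/gradient step that threads Haiku's mutable state;
# the loss and batch shapes are illustrative assumptions
def loss_fn(params, state, rng, x, y):
    logits, new_state = forward.apply(params, state, rng, x, is_training=True)
    loss = jnp.mean((logits - y) ** 2)
    return loss, new_state  # return the updated state as an auxiliary output

grad_fn = jax.grad(loss_fn, has_aux=True)

x = jnp.ones([8, 128])
y = jnp.ones([8, 10])
grads, state = grad_fn(params, state, next(key), x, y)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)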
When scaling training using pmap for data parallelism, managing state correctly across multiple devices becomes essential. In a typical pmap data-parallel setup, parameters and optimizer state are replicated, data is sharded, and PRNG keys must be unique per device. The mapped function computes per-device results such as the loss and gradients.
Parameters and Optimizer State: For standard data parallelism, the model parameters (params) and the associated optimizer state (opt_state) are identical on every device. You initialize them once on the host and replicate them across devices, either explicitly (e.g., with jax.device_put_replicated) or by declaring them as broadcast arguments via pmap's in_axes. When gradients are computed on each device, they need to be aggregated (e.g., averaged using lax.pmean) before updating the parameters and optimizer state. Performing this update inside the pmapped function keeps the replicated state consistent across devices.
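As a short sketch of explicit replication (reusing whichever params and opt_state PyTrees you are training, with jax already imported), each leaf gains a leading device axis, so the replicated PyTrees can then be passed to a pmapped function with the default in_axes=0.

# Explicitly replicate the host copies across local devices;
# every leaf gains a leading axis of size num_devices
devices = jax.local_devices()
replicated_params = jax.device_put_replicated(params, devices)
replicated_opt_state = jax.device_put_replicated(opt_state, devices)

# Each leaf now has shape (num_devices, *original_shape)
print(jax.tree_util.tree_map(lambda x: x.shape, replicated_params))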
PRNG Keys: This is a frequent source of subtle errors. If you use the same PRNG key on all devices for operations like dropout or data augmentation, every device will produce the same random numbers, defeating the purpose of stochasticity or leading to correlated results. The correct approach is to generate a main key on the host and split it into unique sub-keys, one per device, before calling pmap.
import functools
from jax import lax

num_devices = jax.local_device_count()
key = jax.random.PRNGKey(42)

# Split the key once on the host: one unique sub-key per device
device_keys = jax.random.split(key, num_devices)

# Example pmapped function (simplified); axis_name is needed for lax.pmean
@functools.partial(jax.pmap, axis_name='devices')
def train_step_pmap(params, opt_state, local_key, batch):
    # Use the per-device key 'local_key' inside, splitting off fresh sub-keys
    dropout_key, new_local_key = jax.random.split(local_key)
    # ... perform forward pass with dropout using dropout_key ...
    # ... compute per-device gradients ...
    # grads = lax.pmean(grads, axis_name='devices')  # Aggregate (average) grads
    # ... update params and opt_state ...
    # return loss, new_params, new_opt_state, new_local_key
    ...

# Call the pmapped step, passing replicated state and the array of device-specific keys
# loss, params, opt_state, device_keys = train_step_pmap(params, opt_state, device_keys, sharded_batch)
Create an array of unique PRNG keys, one for each device, by splitting a master key before passing them into pmap.
Each execution of the pmapped function on a specific device then receives its corresponding unique key from the device_keys array. Remember to also return the updated key from the pmapped function so you can continue the PRNG sequence correctly in the next step.
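The tiny, runnable illustration below (independent of the training sketch above, using a trivial per-device random draw) shows this threading: each call consumes the per-device keys and returns fresh ones to feed into the next call.

# Each device draws its own random sample and returns an advanced key
@functools.partial(jax.pmap, axis_name='devices')
def draw_and_advance(local_key):
    sample_key, new_local_key = jax.random.split(local_key)
    return jax.random.normal(sample_key, ()), new_local_key

samples, device_keys = draw_and_advance(device_keys)
print(samples)  # one distinct sample per device
samples, device_keys = draw_and_advance(device_keys)  # keys keep advancing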
Training large models can take hours, days, or even weeks. It's essential to periodically save the training state (checkpointing) to disk. This allows you to resume training later if interrupted and saves the final trained model.
What needs to be saved? Typically, you need:
Model parameters (params): The learned weights and biases.
Optimizer state (opt_state): Crucial for resuming training correctly, especially for optimizers with momentum or adaptive learning rates (like Adam).
Libraries within the JAX ecosystem provide tools for this. For instance, orbax.checkpoint is becoming a standard solution, offering features like asynchronous checkpointing (saving state in the background without stalling training) and flexible ways to structure saved data. Frameworks like Flax also have built-in serialization utilities (flax.training.checkpoints).
# Conceptual example using Orbax (simplified)
import orbax.checkpoint as ocp
# Assume 'state' is a PyTree containing params, opt_state, step etc.
# For pmap, state might be replicated. Usually save from device 0.
# state_to_save = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state)) # Get state from device 0
checkpointer = ocp.StandardCheckpointer() # Or AsyncCheckpointer
save_path = '/path/to/checkpoints/step_10000'
checkpointer.save(save_path, args=ocp.args.StandardSave(state))
# To restore:
# restored_state = checkpointer.restore(save_path, args=ocp.args.StandardRestore(state))
# If using pmap, replicate the restored state across devices
# state = jax.device_put_replicated(restored_state, jax.local_devices())
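For periodic checkpointing during a long run, recent versions of orbax.checkpoint also provide a CheckpointManager that saves on a fixed step interval and keeps a rolling window of checkpoints. The directory path, interval, and retention count below are placeholder choices, and the exact constructor options may vary between Orbax versions.

# A sketch of periodic checkpointing with Orbax's CheckpointManager
# (placeholder directory, interval, and retention settings)
options = ocp.CheckpointManagerOptions(save_interval_steps=1000, max_to_keep=3)
manager = ocp.CheckpointManager('/path/to/checkpoints', options=options)

# Inside the training loop:
# manager.save(step, args=ocp.args.StandardSave(state))  # no-op unless 'step' hits the interval
# manager.wait_until_finished()  # make sure any async saves complete before exiting

# To restore the most recent checkpoint:
# restored_state = manager.restore(manager.latest_step(), args=ocp.args.StandardRestore(state))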
Effectively managing parameters and state is a foundational element of building and training large-scale models in JAX. By leveraging PyTrees and the abstractions provided by ecosystem libraries, and by carefully handling state distribution and persistence, you can build robust and scalable training loops.