While JAX provides the powerful pmap transformation for distributing computations across multiple devices, structuring complex models and their training loops requires organization. High-level neural network libraries like Flax or Haiku offer abstractions for defining models, managing parameters, and handling state, which simplifies development significantly. Integrating pmap with these frameworks allows us to combine the benefits of structured model building with efficient data parallelism.
The core idea remains the Single-Program, Multiple-Data (SPMD) paradigm implemented by pmap. We write a function (typically a training step) as if it were running on a single device, but pmap transforms it to run concurrently on multiple devices, each operating on a different slice of the input data. The frameworks help manage the model parameters and optimizer state, which need to be handled correctly in this distributed setting.
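As a minimal illustration of the SPMD idea, independent of any framework, here is a toy pmap that sums one row of data per device (the array shapes here are made up for the example):
import jax
import jax.numpy as jnp

n = jax.local_device_count()
xs = jnp.arange(n * 4.0).reshape(n, 4)        # one row of data per device
per_device_sums = jax.pmap(lambda x: x.sum())(xs)
print(per_device_sums)                        # shape (n,): one result per device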
When using data parallelism with pmap, the model itself is typically replicated across all participating devices. Each device holds a complete copy of the model's parameters. Similarly, the optimizer state (e.g., momentum buffers in Adam) also needs to be replicated so that each device can compute parameter updates based on its local data shard.
Frameworks like Flax often manage parameters and state within structured containers (such as Python dictionaries or specialized dataclasses, often called a "train state"). Before the pmapped training loop begins, we initialize the model parameters and optimizer state on the host (CPU) and then explicitly replicate this state across all available devices. JAX provides utilities to facilitate this replication.
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax # Common optimizer library used with Flax
# Assume 'model' is a Flax nn.Module instance
# Assume 'optimizer' is an Optax optimizer instance
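# For concreteness, a hypothetical minimal 'model' satisfying the assumptions
# above (added for illustration; any Flax nn.Module would do). Its single
# Dense layer is auto-named 'Dense_0', matching the parameter path inspected below.
class SimpleClassifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))   # Flatten [B, 28, 28, 1] -> [B, 784]
        return nn.Dense(features=10)(x)   # Logits for 10 classes

model = SimpleClassifier()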
# Example initialization (on host CPU)
key = jax.random.PRNGKey(0)
dummy_input = jnp.ones([1, 28, 28, 1]) # Example input shape
params = model.init(key, dummy_input)['params']
tx = optax.adam(learning_rate=1e-3)
optimizer_state = tx.init(params)  # Shown for illustration; TrainState.create below calls tx.init itself
# Create a TrainState object (common Flax pattern)
# This bundles parameters, optimizer state, and apply_fn
class TrainState(train_state.TrainState):
pass # Can add custom fields if needed
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# Replicate the state across devices
num_devices = jax.local_device_count()
replicated_state = jax.device_put_replicated(state, jax.local_devices())
print(f"State replicated across {num_devices} devices.")
# Example check: shape of the replicated parameters (note the leading device dimension)
print(jax.tree_util.tree_map(lambda x: x.shape, replicated_state.params)['Dense_0']['kernel'])
# Example check: shape of the original host parameters (single copy) for comparison
print(jax.tree_util.tree_map(lambda x: x.shape, state.params)['Dense_0']['kernel'])
Notice how jax.device_put_replicated creates a version of the state where each leaf node (parameter array, optimizer state array) gains a leading dimension equal to the number of devices.
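If you are working with Flax, its jax_utils module provides convenience helpers for the same replication step (and its inverse, which is handy for pulling a single copy back to the host for logging or checkpointing); a brief sketch:
from flax import jax_utils

replicated_state = jax_utils.replicate(state)           # same effect as above
single_copy = jax_utils.unreplicate(replicated_state)   # drop the device dimension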
For data parallelism, the global batch of training data needs to be split evenly across the devices. If you have N devices and a global batch size of B, each device processes a local batch of size B // N. This sharding needs to happen before the data is passed to the pmapped function.
global_batch_size = 64
local_batch_size = global_batch_size // num_devices
# Assume 'global_images' and 'global_labels' are NumPy arrays
# with shape [global_batch_size, ...]
def shard_batch(batch):
"""Reshapes and shards data across devices."""
return jax.tree_util.tree_map(
lambda x: x.reshape((num_devices, local_batch_size) + x.shape[1:]),
batch
)
# Example data (replace with actual data loading)
global_images = jnp.ones([global_batch_size, 28, 28, 1])
global_labels = jnp.ones([global_batch_size], dtype=jnp.int32)
batch = {'image': global_images, 'label': global_labels}
sharded_batch = shard_batch(batch)
# Verify shapes
print("Global image shape:", global_images.shape)
print("Sharded image shape:", sharded_batch['image'].shape)
# Output should show: Sharded image shape: (num_devices, local_batch_size, 28, 28, 1)
The shard_batch utility uses jax.tree_util.tree_map to handle arbitrary batch structures (like dictionaries) and reshapes each data array to have a leading dimension matching the number of devices.
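Note that this reshape assumes the global batch size is divisible by the number of devices; a simple guard, added here for illustration, makes that assumption explicit and fails with a clearer message than a reshape error would:
# Sanity check before sharding (illustrative addition)
assert global_batch_size % num_devices == 0, (
    "global_batch_size must be divisible by the number of devices")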
The core of the distributed training loop is the function that performs a single optimization step. This function typically calculates the loss, computes gradients, and determines the parameter updates for a single device's local batch. When using frameworks, this often involves calling the model's apply method and using standard JAX automatic differentiation (jax.value_and_grad).
A significant aspect is handling gradient aggregation. Since each device computes gradients based only on its local data shard, these gradients need to be averaged across all devices before updating the replicated model parameters. This ensures that the parameter update reflects the gradient of the loss over the entire global batch. The jax.lax.pmean collective operation is used for this purpose within the pmapped function.
def compute_loss(params, batch, apply_fn):
"""Computes cross-entropy loss for a batch."""
logits = apply_fn({'params': params}, batch['image'])
one_hot_labels = jax.nn.one_hot(batch['label'], num_classes=10)
loss = -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)
return jnp.mean(loss)
def train_step(state, batch):
"""Performs one training step on a single device's data shard."""
# Calculate loss and gradients for the local batch
grad_fn = jax.value_and_grad(compute_loss)
loss, grads = grad_fn(state.params, batch, state.apply_fn)
# **** Crucial: Average gradients across devices ****
# 'batch_axis' is the name we give to the pmap-ed dimension
averaged_grads = jax.lax.pmean(grads, axis_name='batch_axis')
# Update the state using the averaged gradients
new_state = state.apply_gradients(grads=averaged_grads)
# Can also compute and aggregate metrics here (e.g., accuracy)
# metrics = {'loss': loss, 'accuracy': compute_accuracy(logits, batch['label'])}
# averaged_metrics = jax.lax.pmean(metrics, axis_name='batch_axis')
return new_state, loss # Return updated state and local loss
# Now, pmap the train_step function
# Specify the axis_name used in pmean
p_train_step = jax.pmap(train_step, axis_name='batch_axis')
# --- In the training loop ---
# Assume 'sharded_batch' is prepared as shown before
# replicated_state holds the state replicated across devices
# Execute the parallel training step
replicated_state, local_losses = p_train_step(replicated_state, sharded_batch)
# local_losses will have shape (num_devices,)
# Average loss across devices for logging (optional, done on host)
avg_loss = jnp.mean(local_losses)
print(f"Average loss across devices: {avg_loss:.4f}")
# The replicated_state now contains the updated parameters and
# optimizer state, consistent across all devices.
In this example:
- compute_loss defines how to calculate the loss for a given set of parameters and a batch, using the model's apply_fn stored in the state.
- train_step calculates the loss and gradients using jax.value_and_grad.
- jax.lax.pmean(grads, axis_name='batch_axis') averages the gradient pytree (grads) across all devices participating in the pmap operation identified by the name 'batch_axis'.
- state.apply_gradients is a method provided by flax.training.train_state.TrainState that uses the optimizer (state.tx) to update the parameters (state.params) with the provided gradients.
- jax.pmap(train_step, axis_name='batch_axis') creates the parallel version of train_step. The axis_name argument is essential; it connects the pmap operation to the collective operations (like pmean) used inside the function.
- When p_train_step is called with the replicated_state and sharded_batch, JAX executes train_step on each device with its corresponding slice of state and data. The pmean operation synchronizes the devices to average the gradients.
- The result is the updated replicated_state (consistent across all devices) and the loss calculated on each device's local shard.
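Putting these pieces together, an outer training loop simply shards each incoming global batch and calls the pmapped step; a brief sketch, where get_batches is a hypothetical data iterator not defined in this section:
# Hypothetical outer loop over global batches
for step, batch in enumerate(get_batches()):
    sharded = shard_batch(batch)
    replicated_state, local_losses = p_train_step(replicated_state, sharded)
    if step % 100 == 0:
        print(f"step {step}, mean loss {jnp.mean(local_losses):.4f}")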
Stochastic operations like dropout require careful handling of random number generator (RNG) keys in a distributed setting. Simply replicating the same key would lead to identical dropout masks on all devices, negating the benefit of randomness.
A common strategy is to split the main PRNG key on the host and provide a different subkey to each device. Frameworks like Flax often provide mechanisms to handle this automatically when initializing or applying models, usually requiring you to pass specific RNG streams (e.g., one for 'params' initialization, one for 'dropout'). When using pmap, you need to ensure these per-device keys are correctly passed into the pmapped function. Often, this involves splitting a key outside the pmapped function and including the resulting sharded keys as part of the input arguments.
# --- Outside the training loop ---
main_key = jax.random.PRNGKey(42)
# --- Inside the training loop ---
# Split key for the current step
step_key, main_key = jax.random.split(main_key)
# Split key across devices for operations like dropout within the model
dropout_keys = jax.random.split(step_key, num_devices)
# Modify train_step to accept and use dropout keys
def train_step_with_rng(state, batch, dropout_rng):
    """Training step that threads a per-device dropout RNG into the model."""
    def loss_fn(params):
        # Pass the per-device key so each replica draws a different dropout mask
        logits = state.apply_fn({'params': params}, batch['image'],
                                rngs={'dropout': dropout_rng})
        one_hot = jax.nn.one_hot(batch['label'], num_classes=10)
        return jnp.mean(-jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    grads = jax.lax.pmean(grads, axis_name='batch_axis')  # Average across devices, as before
    return state.apply_gradients(grads=grads), loss
p_train_step_with_rng = jax.pmap(train_step_with_rng, axis_name='batch_axis')
# Call with sharded dropout keys
# replicated_state, local_losses = p_train_step_with_rng(replicated_state, sharded_batch, dropout_keys)
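An alternative pattern derives a distinct key on each device inside the pmapped function itself, by folding the device's position along the named axis into a shared key; a small sketch, assuming the same axis_name as above:
def per_device_key(common_key):
    # jax.lax.axis_index returns this device's index along 'batch_axis' when
    # called inside the pmapped function; folding it into the key gives each
    # replica its own reproducible RNG stream.
    return jax.random.fold_in(common_key, jax.lax.axis_index('batch_axis'))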
Combining pmap with frameworks like Flax or Haiku provides a scalable and organized way to implement data-parallel training. The framework handles model definition and state management, while pmap orchestrates the distribution and execution across multiple devices, including the necessary gradient synchronization using collective operations. This pattern is fundamental for training large models efficiently on modern accelerator hardware.