Training large neural networks often requires batch sizes that exceed the available memory on a single accelerator (GPU or TPU). While distributing the computation across multiple devices using pmap (as discussed in Chapter 3) helps, sometimes even the per-device batch size needed for stable training or optimal performance is too large for the device's memory. Gradient accumulation provides a direct solution to this challenge.
The core idea is to simulate a large batch by processing several smaller batches sequentially, accumulating their gradients, and then performing a single optimizer step using the aggregated gradient information. This effectively decouples the batch size used for the weight update from the batch size that must fit into memory at any one time.
Imagine you want to train with an effective batch size of B, but your accelerator can only handle a smaller batch size, the micro-batch size b, where B = N × b. Gradient accumulation achieves this by performing the following steps:

1. Initialize a gradient accumulator with zeros shaped like the model parameters.
2. For each of the N micro-batches, compute the gradients and add them to the accumulator.
3. Divide the accumulated gradients by N to obtain the averaged gradient.
4. Apply a single optimizer step using the averaged gradient, then reset the accumulator for the next cycle.

This process computes the gradient:
$$\nabla_\theta L_{\text{eff}} = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta L(\text{micro-batch}_i; \theta)$$

This averaged gradient approximates the gradient that would have been computed using the full effective batch of size B. The key benefit is that only one micro-batch needs to reside in accelerator memory at a time during the gradient computation phase.
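When the loss is an average over examples and the micro-batches are equally sized, the match is exact. A quick toy check (a minimal sketch with a mean-squared-error loss and illustrative shapes; the names here are not from the main example) confirms this:

import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    # Mean-squared-error loss, averaged over the batch dimension.
    return jnp.mean((x @ w - y) ** 2)

w = jax.random.normal(jax.random.PRNGKey(0), (4,))
x = jax.random.normal(jax.random.PRNGKey(1), (8, 4))  # effective batch B = 8
y = jax.random.normal(jax.random.PRNGKey(2), (8,))

full_grad = jax.grad(mse_loss)(w, x, y)

# N = 2 micro-batches of size b = 4: compute and average their gradients.
micro_grads = [jax.grad(mse_loss)(w, x[i:i + 4], y[i:i + 4]) for i in (0, 4)]
accumulated = sum(micro_grads) / len(micro_grads)

print(jnp.allclose(full_grad, accumulated, atol=1e-6))  # True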
Implementing gradient accumulation in JAX typically involves modifying the training step function. Instead of computing gradients and applying the update in one go, we separate these steps and introduce a loop.
Let's consider a typical JAX training step function that takes the model state (parameters, optimizer state), and a batch of data, computes loss and gradients, and returns the updated state and metrics.
import jax
import jax.numpy as jnp
import optax  # Example optimizer library

# Assume 'model', 'loss_fn', and 'optimizer' are defined elsewhere.
# 'params' are the model parameters.
# 'opt_state' is the optimizer state.

@jax.jit
def train_step(params, opt_state, batch):
    """Performs a single training step WITHOUT gradient accumulation."""
    def compute_loss(p):
        logits = model.apply({'params': p}, batch['image'])
        loss = loss_fn(logits, batch['label'])
        return loss

    grad_fn = jax.value_and_grad(compute_loss)
    loss, grads = grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    metrics = {'loss': loss}
    return params, opt_state, metrics
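For reference, here is one possible set of definitions for the names this snippet assumes (model, loss_fn, optimizer, params, opt_state), sketched with Flax and Optax. The architecture, the 784-dimensional flattened inputs, and the 10 output classes are illustrative placeholders for your own setup.

import flax.linen as nn

# Hypothetical stand-ins for the model, loss function, and optimizer assumed above.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(10)(x)  # 10-class logits

model = MLP()
optimizer = optax.adam(1e-3)

def loss_fn(logits, labels):
    # Mean cross-entropy over the (micro-)batch; integer class labels assumed.
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

# Initialize parameters and optimizer state (28x28 flattened inputs assumed).
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']
opt_state = optimizer.init(params)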
To incorporate gradient accumulation, we need to manage the accumulated gradients and loop over micro-batches.
# Assume accumulation_steps = N (integer > 1)

@jax.jit
def micro_batch_step(params, micro_batch):
    """Computes gradients for a single micro-batch."""
    def compute_loss(p):
        logits = model.apply({'params': p}, micro_batch['image'])
        loss = loss_fn(logits, micro_batch['label'])
        return loss, logits  # Also return logits for potential metrics

    # Use value_and_grad to get loss and gradients
    (loss, _), grads = jax.value_and_grad(compute_loss, has_aux=True)(params)
    # Note: Returning loss here might be tricky if averaging across steps.
    # Often, metrics are calculated based on the full effective batch.
    return grads, loss
# This function handles the accumulation loop and the optimizer update.
# It's often NOT fully JIT-compiled because the loop handles data loading.
# However, the micro_batch_step it calls IS JIT-compiled.

# The optimizer update is JIT-compiled separately. Defining it once at module
# level (rather than inside the training step) avoids re-tracing it on every call.
@jax.jit
def apply_update(params, opt_state, avg_grads):
    updates, new_opt_state = optimizer.update(avg_grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

def accumulated_train_step(params, opt_state, data_iterator, accumulation_steps):
    """Performs a training step WITH gradient accumulation."""
    # 1. Initialize accumulated gradients (as zeros shaped like params)
    accumulated_grads = jax.tree_util.tree_map(jnp.zeros_like, params)
    total_loss = 0.0

    # 2. Micro-batch Loop
    for _ in range(accumulation_steps):
        micro_batch = next(data_iterator)  # Fetch the next micro-batch
        # Compute gradients for this micro-batch (uses the JIT-compiled function)
        grads, loss = micro_batch_step(params, micro_batch)
        # Accumulate gradients
        accumulated_grads = jax.tree_util.tree_map(
            lambda acc, g: acc + g, accumulated_grads, grads)
        total_loss += loss

    # 3. Parameter Update
    # Average the gradients and the loss
    averaged_grads = jax.tree_util.tree_map(
        lambda g: g / accumulation_steps, accumulated_grads)
    average_loss = total_loss / accumulation_steps

    # Apply the (JIT-compiled) optimizer update
    params, opt_state = apply_update(params, opt_state, averaged_grads)

    metrics = {'loss': average_loss}  # Report average loss over micro-batches
    return params, opt_state, metrics
In practice, structuring this within a larger training loop involves creating a data iterator that yields micro-batches and calling accumulated_train_step once per optimizer update.
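For example, a bare-bones outer loop might look like the following sketch, where micro_batch_iterator (yielding dictionaries with 'image' and 'label' arrays) and num_updates are hypothetical placeholders:

# Hypothetical outer training loop using the functions defined above.
accumulation_steps = 4  # effective batch = 4 x micro-batch size

for step in range(num_updates):
    params, opt_state, metrics = accumulated_train_step(
        params, opt_state, micro_batch_iterator, accumulation_steps)
    if step % 100 == 0:
        print(f"step {step}: loss = {metrics['loss']:.4f}")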
A more functional approach using jax.lax.scan can encapsulate the accumulation loop within a single JIT-compiled function, but it requires careful state management and structuring the data loading appropriately. For clarity, the explicit loop shown above often illustrates the concept more directly.
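That said, a rough sketch of such a scan-based step might look like the following. It assumes the N micro-batches have been pre-stacked into arrays with a leading dimension of size accumulation_steps (so all the data is already device-resident), and it reuses micro_batch_step and the optimizer names from above.

# Sketch: the entire accumulated update as one JIT-compiled function.
# 'stacked_batches' is a dict of arrays shaped
# (accumulation_steps, micro_batch_size, ...), e.g. for 'image' and 'label'.
@jax.jit
def scanned_train_step(params, opt_state, stacked_batches):
    def accumulate(carry, micro_batch):
        grad_sum, loss_sum = carry
        grads, loss = micro_batch_step(params, micro_batch)
        grad_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return (grad_sum, loss_sum + loss), None

    accumulation_steps = stacked_batches['image'].shape[0]
    init = (jax.tree_util.tree_map(jnp.zeros_like, params), jnp.zeros(()))
    (grad_sum, loss_sum), _ = jax.lax.scan(accumulate, init, stacked_batches)

    averaged_grads = jax.tree_util.tree_map(
        lambda g: g / accumulation_steps, grad_sum)
    updates, opt_state = optimizer.update(averaged_grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, {'loss': loss_sum / accumulation_steps}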
The process of gradient accumulation involves initializing gradients, looping through micro-batches to compute and sum gradients, averaging the result, applying the optimizer update, and resetting for the next cycle.
Gradient accumulation combines naturally with data parallelism using pmap. Each device processes its own shard of every micro-batch, and the accumulation loop runs independently on each device. Before the optimizer update, lax.pmean is typically used to average the accumulated gradients across all devices, so every device applies an update based on gradient information from the entire effective batch B. Each device therefore only needs enough memory for its share of a single micro-batch (b / num_devices), while the effective batch size used for the gradient update remains B.
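A rough sketch of this combination, reusing the scan-based pattern from above, might look like the following. It assumes each device receives its own stack of micro-batch shards, shaped (accumulation_steps, per_device_micro_batch, ...), and that params and opt_state are already replicated across devices.

import functools

# Sketch: gradient accumulation under pmap. Each device accumulates gradients
# over its own micro-batch shards, then lax.pmean averages them across devices
# before the (replicated) optimizer update. Loss metrics are omitted for brevity.
@functools.partial(jax.pmap, axis_name='devices')
def pmapped_accumulated_step(params, opt_state, stacked_shards):
    def accumulate(grad_sum, micro_batch):
        grads, _ = micro_batch_step(params, micro_batch)
        return jax.tree_util.tree_map(jnp.add, grad_sum, grads), None

    accumulation_steps = stacked_shards['image'].shape[0]
    init = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, _ = jax.lax.scan(accumulate, init, stacked_shards)

    averaged_grads = jax.tree_util.tree_map(
        lambda g: g / accumulation_steps, grad_sum)
    # Average across devices: every replica now updates with gradient
    # information from the entire effective batch B.
    averaged_grads = jax.lax.pmean(averaged_grads, axis_name='devices')
    updates, opt_state = optimizer.update(averaged_grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state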
Integrating gradient accumulation with pmap, and potentially with frameworks like Flax or Haiku, requires careful handling of state and data flow. Gradient accumulation is a fundamental technique for pushing the boundaries of model scale when faced with memory limitations. It is often used in combination with other methods, such as gradient checkpointing and mixed precision, to train state-of-the-art models.