As discussed conceptually, gradient checkpointing is a valuable technique for reducing the memory footprint of your models during training. It achieves this by avoiding the storage of intermediate activations from designated parts of your network during the forward pass. Instead, these activations are recomputed when they are needed for the gradient calculation during the backward pass. This trades increased computational cost for decreased memory usage, often enabling the training of much larger models than would otherwise fit on the accelerator.
The primary tool for this in JAX is jax.checkpoint, also available under the alias jax.remat. Let's see how to apply it in practice.
Consider a function f composed of two sub-functions, f = f2 ∘ f1, meaning z = f(x) = f2(f1(x)), and let y = f1(x). Normally, the backward pass needs y, so it must be kept in memory from the forward pass onward. If f1 is checkpointed, only its input x is saved, and y is recomputed from x when the backward pass reaches it. This recomputation avoids storing potentially large intermediate tensors like y throughout the entire forward and backward pass.
Using jax.checkpoint
The jax.checkpoint function acts as a wrapper around the function you want to apply checkpointing to. Its basic usage involves passing the function to be checkpointed:
checkpointed_f1 = jax.checkpoint(f1)
# Now use checkpointed_f1 instead of f1 in your model
y = checkpointed_f1(x)
z = f2(y)
When the gradient of the overall computation (involving z) is calculated, JAX's autodiff system knows that checkpointed_f1 requires recomputation during the backward pass.
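To make this concrete, here is a minimal, self-contained sketch. The definitions of f1 and f2 below are hypothetical placeholders chosen for illustration; any differentiable functions would do. The gradients with and without checkpointing are numerically identical, and inspecting the jaxpr of the checkpointed gradient shows the rematerialized sub-computation (displayed as a remat/checkpoint primitive, whose exact name varies across JAX versions):
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for f1 and f2, used only for illustration
def f1(x):
    return jnp.tanh(x * 2.0)           # produces the intermediate y

def f2(y):
    return jnp.sum(y ** 2)             # reduces y to a scalar z

def loss_plain(x):
    return f2(f1(x))

def loss_ckpt(x):
    return f2(jax.checkpoint(f1)(x))   # y is recomputed during the backward pass

x = jnp.arange(4.0)

# Same gradient values; only the memory/compute schedule differs
print(jax.grad(loss_plain)(x))
print(jax.grad(loss_ckpt)(x))

# The jaxpr of the checkpointed gradient contains the rematerialized call to f1
print(jax.make_jaxpr(jax.grad(loss_ckpt))(x))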
Let's define a simple sequence of operations that might represent a block within a larger neural network. We'll make the intermediate dimension large to simulate memory pressure.
import jax
import jax.numpy as jnp
import time
# Define a block of computation
def compute_block(x, W1, W2):
    """A block with a potentially large intermediate activation."""
    y = jnp.dot(x, W1)
    y = jax.nn.gelu(y)  # GELU activation
    # 'y' is the intermediate activation we might want to avoid storing
    z = jnp.dot(y, W2)
    return z
# Define a dummy loss function using this block
def loss_fn(x, W1, W2, targets):
    z = compute_block(x, W1, W2)
    # Simple mean squared error loss
    loss = jnp.mean((z - targets)**2)
    return loss
# Gradient function without checkpointing
grad_fn_standard = jax.jit(jax.value_and_grad(loss_fn, argnums=(1, 2)))
# --- Now, define the checkpointed version ---
# Apply checkpointing to the compute_block
compute_block_checkpointed = jax.checkpoint(compute_block)
# Define the loss using the checkpointed block
def loss_fn_checkpointed(x, W1, W2, targets):
    # Use the checkpointed version here
    z = compute_block_checkpointed(x, W1, W2)
    loss = jnp.mean((z - targets)**2)
    return loss
# Gradient function with checkpointing
grad_fn_checkpointed = jax.jit(jax.value_and_grad(loss_fn_checkpointed, argnums=(1, 2)))
# --- Setup Data ---
key = jax.random.PRNGKey(42)
batch_size = 64
input_dim = 512
hidden_dim = 8192 # Large hidden dimension
output_dim = 512
key, x_key, w1_key, w2_key, t_key = jax.random.split(key, 5)
x = jax.random.normal(x_key, (batch_size, input_dim))
W1 = jax.random.normal(w1_key, (input_dim, hidden_dim)) * 0.02
W2 = jax.random.normal(w2_key, (hidden_dim, output_dim)) * 0.02
targets = jax.random.normal(t_key, (batch_size, output_dim))
# --- Run and Compare ---
print("Running standard version (compilation + execution)...")
start_time = time.time()
loss_std, (dW1_std, dW2_std) = grad_fn_standard(x, W1, W2, targets)
# Ensure computation finishes before stopping timer
loss_std.block_until_ready()
dW1_std.block_until_ready()
dW2_std.block_until_ready()
end_time = time.time()
time_std = end_time - start_time
print(f"Standard Loss: {loss_std:.4f}")
print(f"Standard Time: {time_std:.4f} seconds")
print("\nRunning checkpointed version (compilation + execution)...")
start_time = time.time()
loss_ckpt, (dW1_ckpt, dW2_ckpt) = grad_fn_checkpointed(x, W1, W2, targets)
# Ensure computation finishes
loss_ckpt.block_until_ready()
dW1_ckpt.block_until_ready()
dW2_ckpt.block_until_ready()
end_time = time.time()
time_ckpt = end_time - start_time
print(f"Checkpointed Loss: {loss_ckpt:.4f}")
print(f"Checkpointed Time: {time_ckpt:.4f} seconds")
# Verify gradients are close (should be almost identical)
print("\nComparing gradients...")
print(f"Max absolute difference W1: {jnp.max(jnp.abs(dW1_std - dW1_ckpt)):.2e}")
print(f"Max absolute difference W2: {jnp.max(jnp.abs(dW2_std - dW2_ckpt)):.2e}")
The key observation is that the grad_fn_checkpointed version did not need to store the potentially very large activation y (of size batch_size * hidden_dim; here 64 × 8192 float32 values, about 2 MB, a figure that scales with batch size, width, and the number of such blocks in a real model) during the forward pass for later use in the backward pass. Instead, it recomputed y by re-running the forward logic of compute_block when the gradients with respect to W1 and W2 were calculated. If hidden_dim is large, this memory saving can be substantial. The price is the extra forward recomputation of compute_block, so the checkpointed version typically runs somewhat slower; the exact time difference depends heavily on the relative cost of the forward computation versus the backward computation and on the hardware used. The diagram below illustrates the difference in the backward pass:
The diagram shows the standard backward pass reading the stored activation 'Y', while the checkpointed backward pass recomputes 'Y' using the stored input 'X' just before it's needed for the gradient calculation of the second layer (Dot(Y, W2)).
Gradient checkpointing is most effective when intermediate activations dominate your memory budget and the checkpointed computation is relatively cheap to rerun compared with the memory it frees. You can apply jax.checkpoint selectively to specific layers or blocks within your model, which often requires some experimentation to find the right balance between memory savings and computational overhead for your architecture and hardware; a sketch of this per-block pattern follows below. Frameworks like Flax provide convenient wrappers (e.g., flax.linen.remat) to apply checkpointing to specific modules.
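As a sketch of the per-block pattern (the block definition, width, and depth below are illustrative assumptions, not taken from the example above), you can wrap each block of a deeper network individually, so that only block inputs are saved and each block's internal activations are recomputed on the backward pass:
def block(x, W):
    # One hypothetical block; checkpointing is applied per block below
    return jax.nn.gelu(jnp.dot(x, W))

def deep_model(x, weights):
    # Checkpoint each block: only its inputs are stored, its internal
    # activations are recomputed during the backward pass
    for W in weights:
        x = jax.checkpoint(block)(x, W)
    return jnp.sum(x ** 2)

key = jax.random.PRNGKey(0)
dim, num_blocks = 256, 8
keys = jax.random.split(key, num_blocks + 1)
weights = [jax.random.normal(k, (dim, dim)) * 0.02 for k in keys[:-1]]
x_demo = jax.random.normal(keys[-1], (32, dim))

grads = jax.jit(jax.grad(deep_model, argnums=1))(x_demo, weights)
For finer-grained control, jax.checkpoint also accepts a policy argument (see jax.checkpoint_policies) that lets you declare which intermediate values are cheap enough to recompute and which are worth saving.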
This practical exercise demonstrates how jax.checkpoint provides a direct way to manage the memory-compute trade-off, a significant technique for training large-scale models effectively in JAX.