One of the significant advantages of JAX's design is the composability of its function transformations. You've learned about jit for compilation, grad for differentiation, and vmap for vectorization. Now, let's see how these transformations work together, enabling powerful and efficient computational patterns common in machine learning and scientific computing. Combining them allows you to, for instance, compute gradients over batches of data and compile the entire operation for maximum performance on accelerators.
vmap and jit
You often want to both vectorize a function using vmap and compile it using jit. This is straightforward: you can simply apply one transformation after the other. The typical pattern is to apply vmap first to create the vectorized version of your function, and then apply jit to compile that vectorized function.
import jax
import jax.numpy as jnp
import time
# Define a function that works on single data points
def predict(params, x):
    # A simple linear model: w * x + b
    w, b = params
    return w * x + b
# Create some example parameters and a batch of data
key = jax.random.PRNGKey(0)
params = (jnp.array(2.0), jnp.array(1.0)) # (w, b)
batch_x = jnp.arange(10000.0)
# 1. Vectorize the predict function using vmap
# Map over batch_x (axis 0), but keep params the same (None)
batched_predict_vmap = jax.vmap(predict, in_axes=(None, 0))
# 2. Compile the vectorized function using jit
jitted_batched_predict = jax.jit(batched_predict_vmap)
# --- Timing Comparison ---
# Run the vmap version (without JIT)
start_time = time.time()
result_vmap = batched_predict_vmap(params, batch_x).block_until_ready()
duration_vmap = time.time() - start_time
print(f"vmap only duration: {duration_vmap:.6f} seconds")
# Run the JIT-compiled vmap version (includes compile time on first run)
start_time = time.time()
result_jit_vmap = jitted_batched_predict(params, batch_x).block_until_ready()
duration_jit_vmap_first = time.time() - start_time
print(f"jit(vmap(...)) duration (first run): {duration_jit_vmap_first:.6f} seconds")
# Run the JIT-compiled vmap version again (should be faster)
start_time = time.time()
result_jit_vmap_again = jitted_batched_predict(params, batch_x).block_until_ready()
duration_jit_vmap_second = time.time() - start_time
print(f"jit(vmap(...)) duration (second run): {duration_jit_vmap_second:.6f} seconds")
# Check results are the same
print(f"Results match: {jnp.allclose(result_vmap, result_jit_vmap)}")
Why jit(vmap(f))?
Wrapping vmap with jit (jit(vmap(f))) is generally the preferred order. vmap(f) first creates a new Python function that applies f across the mapped axes of the inputs. Internally, vmap transforms the JAX primitives within f to operate on batches. jit(...) then takes this vectorized function and compiles it using XLA into optimized kernels for your target hardware (CPU/GPU/TPU). This allows the compiler to see the entire batched operation and optimize it as a whole.
While vmap(jit(f)) is also possible, it compiles the inner function f first and then vectorizes the call to the compiled function. This might be less efficient, as the vectorization logic operates "outside" the compiled kernel, potentially leading to less optimal hardware utilization compared to compiling the already-vectorized code. For most common use cases, jit(vmap(f)) provides better performance.
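As a quick check, both orderings compute the same values. This is a minimal sketch, reusing predict, params, and batch_x from the example above:
# Both compositions are valid; jit(vmap(f)) is usually the preferred order.
f_jit_vmap = jax.jit(jax.vmap(predict, in_axes=(None, 0)))  # compile the batched function
f_vmap_jit = jax.vmap(jax.jit(predict), in_axes=(None, 0))  # vectorize calls to a compiled function
print(jnp.allclose(f_jit_vmap(params, batch_x), f_vmap_jit(params, batch_x)))  # True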
vmap and grad
Another common requirement, especially in machine learning, is to compute gradients not just for a single data point, but for an entire batch. vmap makes this easy to express. You can compute the gradient of your function using grad and then vectorize the resulting gradient function using vmap.
Let's say you have a function that calculates a loss based on parameters and a single data point. You often want the gradient of this loss with respect to the parameters, calculated for each item in a batch.
import jax
import jax.numpy as jnp
# Example: Squared error loss for a single data point
def squared_error(params, x, y_true):
    w, b = params
    y_pred = w * x + b
    loss = (y_pred - y_true)**2
    return loss
# Function to compute gradient w.r.t. params (arg 0) for a single data point
grad_loss_single = jax.grad(squared_error, argnums=0)
# Example parameters and a batch of data (x, y_true)
key = jax.random.PRNGKey(1)
params = (jnp.array(2.0), jnp.array(1.0)) # (w, b)
batch_x = jnp.linspace(0, 1, 5)
batch_y_true = 2.5 * batch_x + 0.5 # True relationship: w=2.5, b=0.5
# Vectorize the gradient function:
# Map over batch_x (axis 0) and batch_y_true (axis 0)
# Keep params the same for all calculations (None)
grad_loss_batch = jax.vmap(grad_loss_single, in_axes=(None, 0, 0))
# Compute gradients for the entire batch
batch_gradients = grad_loss_batch(params, batch_x, batch_y_true)
# batch_gradients will be a tuple (dw, db) where dw and db are arrays
# with the gradient computed for each item in the batch.
print("Parameters (w, b):", params)
print("Batch x:", batch_x)
print("Batch y_true:", batch_y_true)
print("\nPer-example gradients (dw):", batch_gradients[0])
print("Per-example gradients (db):", batch_gradients[1])
print("\nShape of dw:", batch_gradients[0].shape) # Should be (5,)
print("Shape of db:", batch_gradients[1].shape) # Should be (5,)
Here, grad_loss_single computes the gradient (∂L/∂w, ∂L/∂b) for a single (x, y_true) pair. Applying vmap to grad_loss_single effectively loops this gradient computation over the batch_x and batch_y_true arrays, returning the gradients for each example. The shape of the resulting gradients reflects the batch dimension added by vmap.
This pattern is useful for algorithms that require per-example gradients. More commonly in deep learning, you might want the mean gradient across the batch. You could compute the per-example gradients as above and then average them, or you could define your loss function to compute the mean loss before taking the gradient:
import jax
import jax.numpy as jnp
# Example: Mean squared error loss over a batch
def mean_squared_error(params, batch_x, batch_y_true):
# Use vmap *inside* the loss function for prediction
# Note: This is often less explicit than vmap(grad(...))
batched_predict = jax.vmap(predict, in_axes=(None, 0))
batch_y_pred = batched_predict(params, batch_x)
loss = jnp.mean((batch_y_pred - batch_y_true)**2)
return loss
# Compute the gradient of the *mean* loss w.r.t params
grad_mean_loss = jax.grad(mean_squared_error, argnums=0)
# Compute the single gradient vector representing the average gradient over the batch
mean_batch_gradient = grad_mean_loss(params, batch_x, batch_y_true)
print("\n--- Mean Gradient ---")
print("Mean batch gradient (dw, db):", mean_batch_gradient)
print("Shape of mean dw:", mean_batch_gradient[0].shape) # Should be () - scalar
print("Shape of mean db:", mean_batch_gradient[1].shape) # Should be () - scalar
# You can verify this matches the mean of the per-example gradients
print("Mean of per-example dw:", jnp.mean(batch_gradients[0]))
print("Mean of per-example db:", jnp.mean(batch_gradients[1]))
print(f"Mean gradients match: {jnp.allclose(mean_batch_gradient[0], jnp.mean(batch_gradients[0]))}")
The pattern vmap(grad(f, ...), ...) is often clearer when you explicitly need per-example results, while grad(mean_loss_fn, ...) is standard for typical gradient descent optimization, where only the average gradient across the batch is needed.
vmap, grad, and jit
Now, let's combine all three. This is a fundamental pattern for efficient training of machine learning models in JAX. You typically want to compute the gradient of your loss with grad, vectorize that computation over a batch with vmap, and compile the whole thing with jit for high performance on accelerators.
The most common and often most performant ordering for computing the average gradient over a batch is jit(grad(mean_loss_fn)). If you need per-example gradients efficiently, the pattern is jit(vmap(grad(single_loss_fn))).
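For the average-gradient case, here is a minimal sketch of jit(grad(mean_loss_fn)), reusing mean_squared_error, params, batch_x, and batch_y_true from the earlier example:
# Compile the gradient of the mean loss; a single call returns the batch-averaged gradient.
jit_grad_mean_loss = jax.jit(jax.grad(mean_squared_error, argnums=0))
avg_dw, avg_db = jit_grad_mean_loss(params, batch_x, batch_y_true)
print("Average gradient (dw, db):", avg_dw, avg_db)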
Let's illustrate the jit(vmap(grad(f))) pattern for efficient per-example gradients:
import jax
import jax.numpy as jnp
import time
# Use the single-example squared_error from before
def squared_error(params, x, y_true):
    w, b = params
    y_pred = w * x + b
    loss = (y_pred - y_true)**2
    return loss
# 1. Get the gradient function for a single example w.r.t. params
grad_loss_single = jax.grad(squared_error, argnums=0)
# 2. Vectorize the single-example gradient function
# Map over batch_x (axis 0) and batch_y_true (axis 0)
# Keep params fixed (None)
vmap_grad_loss = jax.vmap(grad_loss_single, in_axes=(None, 0, 0))
# 3. Compile the vectorized gradient function
jit_vmap_grad_loss = jax.jit(vmap_grad_loss)
# Prepare larger batch data for timing
key = jax.random.PRNGKey(42)
key_x, key_noise = jax.random.split(key)  # use independent keys for the two random draws
large_batch_size = 100000
params = (jnp.array(2.0), jnp.array(1.0)) # (w, b)
large_batch_x = jax.random.uniform(key_x, (large_batch_size,))
large_batch_y_true = 2.5 * large_batch_x + 0.5 + 0.1 * jax.random.normal(key_noise, (large_batch_size,))
# --- Timing ---
# Run the compiled function (includes compile time)
start_time = time.time()
per_example_grads = jit_vmap_grad_loss(params, large_batch_x, large_batch_y_true)
# Wait for computation to finish before stopping timer
per_example_grads[0].block_until_ready()
per_example_grads[1].block_until_ready()
duration_first = time.time() - start_time
print(f"jit(vmap(grad(...))) duration (first run): {duration_first:.6f} seconds")
# Run again (should be much faster)
start_time = time.time()
per_example_grads_again = jit_vmap_grad_loss(params, large_batch_x, large_batch_y_true)
per_example_grads_again[0].block_until_ready()
per_example_grads_again[1].block_until_ready()
duration_second = time.time() - start_time
print(f"jit(vmap(grad(...))) duration (second run): {duration_second:.6f} seconds")
print(f"\nShape of per-example dw: {per_example_grads[0].shape}")
print(f"Shape of per-example db: {per_example_grads[1].shape}")
By composing jit, vmap, and grad, you create highly optimized functions that compute batched gradients efficiently on modern hardware, forming the core of many JAX-based machine learning workflows.
A few practical notes: when debugging transformed functions, jax.debug.print or temporarily disabling jit can help isolate issues (a short sketch follows below). While vmap is efficient, vectorizing over very large batches can consume significant memory, especially on GPUs/TPUs, which have fixed memory limits; be mindful of your batch sizes relative to available device memory. Finally, jit(vmap(grad(f))) and jit(grad(mean_loss)) are common, effective patterns, but understanding why they work helps you adapt them to new situations.
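For illustration, here is a minimal sketch of the jax.debug.print tip; the small loss function below is written just for this example and is not part of the code above:
import jax
import jax.numpy as jnp

# jax.debug.print emits values at runtime, even inside jit-compiled functions,
# where a plain Python print would only run once at trace time.
@jax.jit
def loss_with_debug(params, x, y_true):
    w, b = params
    y_pred = w * x + b
    loss = jnp.mean((y_pred - y_true) ** 2)
    jax.debug.print("current loss: {l}", l=loss)  # printed on every call
    return loss

loss_with_debug((jnp.array(2.0), jnp.array(1.0)), jnp.arange(4.0), jnp.arange(4.0))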
Mastering the combination of vmap, grad, and jit is essential for writing concise, performant JAX code, particularly when dealing with the batch processing inherent in deep learning and other data-parallel computations.