As introduced, applying a function designed for a single data point across a whole batch is a frequent task. While Python loops or NumPy's inherent vectorization can handle some cases, managing dimensions and loops explicitly can obscure the core logic, especially for more complex functions or when dealing with multiple batched inputs.
JAX provides jax.vmap ("vectorizing map") as a function transformation built specifically for this task. Think of vmap as a way to automatically add a batch dimension to your function without rewriting its internal logic: you write the function as if it operates on a single example, and vmap transforms it into a function that operates efficiently on an entire batch of examples.
Let's see it in action. Suppose we have a function that calculates the sum of the squares of the elements in a single vector:
import jax
import jax.numpy as jnp
# Function designed for a single vector input
def sum_of_squares(vector):
    # This function expects a 1D array (vector)
    print(f"Running sum_of_squares for a vector of shape: {vector.shape}")
    return jnp.sum(vector**2)
# Example single vector
single_vector = jnp.array([1., 2., 3.])
result_single = sum_of_squares(single_vector)
print(f"Result for single vector: {result_single}")
# Expected output: Running sum_of_squares for a vector of shape: (3,)
# Expected output: Result for single vector: 14.0 (1^2 + 2^2 + 3^2)
Now, imagine we have a batch of these vectors, perhaps represented as a matrix where each row is a vector we want to process independently:
# A batch of 4 vectors, each of size 3
batch_of_vectors = jnp.array([
    [1., 2., 3.],
    [4., 5., 6.],
    [7., 8., 9.],
    [0., 1., 0.]
])
print(f"Batch shape: {batch_of_vectors.shape}")
# Expected output: Batch shape: (4, 3)
How can we apply sum_of_squares to each row (vector) in batch_of_vectors? Without vmap, we might write a loop:
# Manual loop approach (demonstration, not recommended in JAX)
manual_results = []
for i in range(batch_of_vectors.shape[0]):
    vector = batch_of_vectors[i]
    result = sum_of_squares(vector)  # Function runs (and potentially traces) for each row
    manual_results.append(result)
manual_output = jnp.stack(manual_results)
print(f"Manual loop output: {manual_output}")
# Expected output:
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Manual loop output: [ 14. 77. 194. 1.]
This works, but it's verbose, relies on Python-level iteration which is often slow, and in JAX, calling the function repeatedly like this inside a loop can lead to inefficient tracing when combined with jit (as discussed in Chapter 2).
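To make the tracing concern concrete, here is a minimal sketch (the looped_sum_of_squares name is ours, not part of the original example). Because jit traces Python code, the loop is unrolled during tracing, embedding one copy of the computation per row in the compiled program:

# Sketch: jitting a Python loop unrolls it at trace time
@jax.jit
def looped_sum_of_squares(batch):
    results = []
    for i in range(batch.shape[0]):  # Python loop, unrolled during tracing
        results.append(sum_of_squares(batch[i]))
    return jnp.stack(results)

# The print inside sum_of_squares fires once per row during tracing,
# and the traced program grows with the batch size.
print(looped_sum_of_squares(batch_of_vectors))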
Enter jax.vmap. We can create a vectorized version of our function simply by wrapping it:
# Create a vectorized version using vmap
vectorized_sum_of_squares = jax.vmap(sum_of_squares)
# Apply it directly to the batch
vmap_output = vectorized_sum_of_squares(batch_of_vectors)
print(f"vmap output: {vmap_output}")
# Expected output:
# Running sum_of_squares for a vector of shape: (3,) <-- Note: This prints only once!
# vmap output: [ 14. 77. 194. 1.]
Notice a few things:

- We didn't modify the original sum_of_squares; we simply applied jax.vmap to it.
- We passed the entire batch_of_vectors (shape (4, 3)) to the vectorized_sum_of_squares function.
- The output [ 14. 77. 194. 1.] contains the sum of squares for each row, matching the manual loop (verified in the quick check below).
- The print statement inside sum_of_squares likely executed only once, during JAX's tracing process for vmap, not once per row as in the Python loop. vmap understands the batching pattern and optimizes the execution.
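As a quick check using the arrays defined above (the fast_batched name is ours), we can confirm the two approaches agree, and note that the vectorized function composes with jit like any other JAX function:

# Confirm the vmap result matches the manual loop
print(jnp.allclose(manual_output, vmap_output))  # Expected: True

# vmap composes with jit: the batched function is traced and compiled once
fast_batched = jax.jit(vectorized_sum_of_squares)
print(fast_batched(batch_of_vectors))  # Expected: [ 14.  77. 194.   1.]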
Conceptually, jax.vmap takes the function sum_of_squares (which expects input shape (3,) and produces output shape ()) and transforms it into vectorized_sum_of_squares. This new function understands that if you give it an input with shape (4, 3), it should map the original function over the leading axis (axis 0), effectively applying it to each of the 4 vectors of size 3. It then stacks the scalar results (each of shape ()) back together along a new leading axis, producing an output of shape (4,).
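If you want to see this shape transformation without executing any computation, jax.eval_shape traces a function with abstract inputs and reports the output shape and dtype (a brief sketch; out_spec is our name):

# Inspect the output shape abstractly, without running the computation
out_spec = jax.eval_shape(vectorized_sum_of_squares, batch_of_vectors)
print(out_spec)  # Expected: something like ShapeDtypeStruct(shape=(4,), dtype=float32)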
By default, vmap maps over the first axis (axis 0) of the input arguments and stacks the results along a new first axis of the output. This default behavior is often exactly what's needed for common batching scenarios in machine learning and numerical computing.
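Spelling out those defaults explicitly makes the behavior visible; the following sketch (explicit_vectorized is our name) is equivalent to the earlier jax.vmap(sum_of_squares) call:

# Explicitly stating the defaults: map over axis 0 of the input,
# stack results along a new axis 0 of the output
explicit_vectorized = jax.vmap(sum_of_squares, in_axes=0, out_axes=0)
print(explicit_vectorized(batch_of_vectors))  # Expected: [ 14.  77. 194.   1.]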
In the following sections, we'll examine how to customize this mapping behavior using in_axes and out_axes for more complex scenarios involving multiple arguments and different batching dimensions.