Applying a function designed for a single data point across a whole batch is a frequent task. While Python loops or NumPy's inherent vectorization can handle some cases, managing dimensions and loops explicitly can obscure the core logic, especially for more complex functions or when dealing with multiple batched inputs.
JAX provides jax.vmap ("vectorizing map") as a function transformation specifically built for this. Think of vmap as a way to automatically add a batch dimension to your function without rewriting its internal logic. You write the function as if it operates on a single example, and vmap transforms it into a function that operates efficiently on an entire batch of examples.
Let's see it in action. Suppose we have a function that calculates the sum of the squares of the elements in a single vector:
import jax
import jax.numpy as jnp
# Function designed for a single vector input
def sum_of_squares(vector):
    # This function expects a 1D array (vector)
    print(f"Running sum_of_squares for a vector of shape: {vector.shape}")
    return jnp.sum(vector**2)
# Example single vector
single_vector = jnp.array([1., 2., 3.])
result_single = sum_of_squares(single_vector)
print(f"Result for single vector: {result_single}")
# Expected output: Running sum_of_squares for a vector of shape: (3,)
# Expected output: Result for single vector: 14.0 (1^2 + 2^2 + 3^2)
Now, imagine we have a batch of these vectors, perhaps represented as a matrix where each row is a vector we want to process independently:
# A batch of 4 vectors, each of size 3
batch_of_vectors = jnp.array([
    [1., 2., 3.],
    [4., 5., 6.],
    [7., 8., 9.],
    [0., 1., 0.]
])
print(f"Batch shape: {batch_of_vectors.shape}")
# Expected output: Batch shape: (4, 3)
How can we apply sum_of_squares to each row (vector) in batch_of_vectors? Without vmap, we might write a loop:
# Manual loop approach (demonstration, not recommended in JAX)
manual_results = []
for i in range(batch_of_vectors.shape[0]):
    vector = batch_of_vectors[i]
    result = sum_of_squares(vector)  # Function runs (and potentially traces) for each row
    manual_results.append(result)
manual_output = jnp.stack(manual_results)
print(f"Manual loop output: {manual_output}")
# Expected output:
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Running sum_of_squares for a vector of shape: (3,)
# Manual loop output: [ 14. 77. 194. 1.]
This works, but it is verbose, relies on slow Python-level iteration, and, in JAX, calling the function repeatedly inside a loop can lead to inefficient tracing when combined with jit (as discussed in Chapter 2).
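To make that tracing concern concrete, here is a small sketch (the helper name looped_batch_sum_of_squares is purely illustrative) that applies jit to a function containing the Python loop, reusing sum_of_squares and batch_of_vectors from above. Because jit traces through Python control flow, the loop is unrolled and sum_of_squares is traced once per row, so the traced program grows with the batch size:
# A sketch of jitting the manual loop (not recommended)
@jax.jit
def looped_batch_sum_of_squares(batch):  # hypothetical helper, for illustration only
    results = []
    for i in range(batch.shape[0]):  # shape is static under jit, so this loop is unrolled at trace time
        results.append(sum_of_squares(batch[i]))
    return jnp.stack(results)

print(looped_batch_sum_of_squares(batch_of_vectors))
# The print inside sum_of_squares fires once per row during the first (tracing) call,
# confirming the function is traced repeatedly rather than once.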
Enter jax.vmap. We can create a vectorized version of our function simply by wrapping it:
# Create a vectorized version using vmap
vectorized_sum_of_squares = jax.vmap(sum_of_squares)
# Apply it directly to the batch
vmap_output = vectorized_sum_of_squares(batch_of_vectors)
print(f"vmap output: {vmap_output}")
# Expected output:
# Running sum_of_squares for a vector of shape: (3,) <-- Note: This prints only once!
# vmap output: [ 14. 77. 194. 1.]
Notice a few things:
- We did not modify sum_of_squares. We simply applied jax.vmap to it.
- We passed the entire batch_of_vectors (shape (4, 3)) to the vectorized_sum_of_squares function.
- The output [ 14. 77. 194. 1.] contains the sum of squares for each row, matching the manual loop.
- The print statement inside sum_of_squares likely executed only once during JAX's tracing process for vmap, not once per row as in the Python loop. vmap understands the batching pattern and optimizes the execution.
Conceptually, jax.vmap takes the function sum_of_squares (which expects input shape (3,) and produces output shape ()) and transforms it into vectorized_sum_of_squares. This new function understands that if you give it an input with shape (4, 3), it should map the original function over the leading axis (axis 0), effectively applying it to each of the 4 vectors of size 3. It then stacks the scalar results (each of shape ()) back together along a new leading axis, producing an output of shape (4,).
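As a quick sanity check (a small sketch reusing the arrays computed above), we can confirm the new leading batch axis and the agreement with the manual loop:
# Verify the batched output shape and that it matches the manual loop
print(vmap_output.shape)
# Expected output: (4,)
print(jnp.allclose(vmap_output, manual_output))
# Expected output: True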
By default, vmap maps over the first axis (axis 0) of the input arguments and stacks the results along a new first axis for the output. This default behavior is often exactly what's needed for common batching scenarios in machine learning and numerical computing.
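To make that default explicit, the call above is equivalent to the following sketch, which simply spells out the in_axes and out_axes arguments that jax.vmap uses by default:
# Equivalent to jax.vmap(sum_of_squares): map over axis 0 of the input,
# stack results along a new axis 0 of the output
explicit_vectorized = jax.vmap(sum_of_squares, in_axes=0, out_axes=0)
print(explicit_vectorized(batch_of_vectors))
# Expected output:
# Running sum_of_squares for a vector of shape: (3,)
# [ 14.  77. 194.   1.]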
In the following sections, we'll examine how to customize this mapping behavior using in_axes and out_axes for more complex scenarios involving multiple arguments and different batching dimensions.