While vmap is powerful for adding a single batch dimension, sometimes your data or computation involves multiple levels of batching or mapping. For instance, you might have a batch of sentences, where each sentence is a sequence (batch) of words, and you want to apply a function to each word embedding. Or perhaps you want to compute the interaction between every element in one batch and every element in another batch.
JAX handles these scenarios elegantly through nesting vmap. Since vmap itself is a function transformation, you can apply it multiple times. Applying vmap to a function that has already been transformed by vmap allows you to map over multiple dimensions simultaneously.
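As a minimal sketch of the sentences-of-words scenario (with made-up shapes), a function written for a single word embedding can be applied to every word of every sentence by nesting two vmaps:
import jax
import jax.numpy as jnp
# A function written for a single word embedding (a 1D vector).
def embedding_norm(embedding):
    return jnp.sqrt(jnp.sum(embedding ** 2))
# Hypothetical data: 2 sentences, 5 words each, 4-dimensional embeddings.
embeddings = jnp.ones((2, 5, 4))
# Inner vmap maps over the words of one sentence; outer vmap maps over sentences.
# (in_axes defaults to 0 at each level.)
per_word_norms = jax.vmap(jax.vmap(embedding_norm))(embeddings)
print(per_word_norms.shape)  # (2, 5): one value per word in each sentence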
Let's reconsider a simple function, like matrix-vector multiplication. Suppose we have a function designed for a single matrix and a single vector:
import jax
import jax.numpy as jnp
def matvec_multiply(matrix, vector):
    """Computes the product of a matrix and a vector."""
    # matrix: [M, N], vector: [N] -> result: [M]
    return jnp.dot(matrix, vector)
# Example single inputs
matrix = jnp.arange(6.).reshape(2, 3) # Shape (2, 3)
vector = jnp.arange(3.) # Shape (3,)
print("Single matvec:", matvec_multiply(matrix, vector))
# Expected output: [ 5. 14.] (Shape (2,))
Now, imagine you have a batch of matrices and a corresponding batch of vectors, and you want to compute the product for each pair.
# Batch of matrices (3 matrices, each 2x3)
batch_of_matrices = jnp.arange(18.).reshape(3, 2, 3)
# Batch of vectors (3 vectors, each 3)
batch_of_vectors = jnp.arange(9.).reshape(3, 3)
We can achieve this by applying vmap once, mapping over the first dimension (axis 0) of both inputs:
# Map over axis 0 for both arguments
batched_matvec = jax.vmap(matvec_multiply, in_axes=(0, 0))
print("Batched matvec (paired):", batched_matvec(batch_of_matrices, batch_of_vectors))
# Expected output shape: (3, 2) -> 3 results, each of shape (2,)
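As a quick sanity check (not part of the original example), the vmapped result should match a plain Python loop over the paired inputs:
# Verify against an explicit loop over the paired matrices and vectors.
looped = jnp.stack([matvec_multiply(m, v)
                    for m, v in zip(batch_of_matrices, batch_of_vectors)])
print("Matches loop:", jnp.allclose(batched_matvec(batch_of_matrices, batch_of_vectors), looped))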
This works because vmap adds one batch dimension. But what if you want to compute the product of each matrix in the batch with each vector in a potentially different batch? This is like computing an "outer product" at the batch level.
This requires mapping over the batch of matrices and independently mapping over the batch of vectors. This is where nesting vmap comes in.
# Batch of matrices (3 matrices, each 2x3)
batch_of_matrices = jnp.arange(18.).reshape(3, 2, 3)
# Another batch of vectors (say, 4 vectors, each 3)
batch_of_vectors_outer = jnp.arange(12.).reshape(4, 3)
# Goal: Compute matvec for all 3x4 combinations. Result shape should be (3, 4, 2)
# Inner vmap: Map over the vectors (axis 0), keeping the matrix fixed (None)
inner_mapped_matvec = jax.vmap(matvec_multiply, in_axes=(None, 0))
# inner_mapped_matvec takes one matrix and a batch of vectors
# Outer vmap: Map inner_mapped_matvec over the matrices (axis 0).
# The vector argument gets in_axes=None at this level: the entire batch of
# vectors is passed unchanged to each call of the inner-mapped function,
# which then maps over it. It is clearest to write the nesting directly:
# Read from inside out:
# 1. vmap(matvec_multiply, in_axes=(None, 0)):
# Creates a function that takes ONE matrix and a BATCH of vectors,
# applies matvec_multiply to the matrix and EACH vector.
# Output shape for one matrix and batch of 4 vectors: (4, 2)
# 2. vmap(..., in_axes=(0, None)):
# Takes the function from step 1. Maps this function over a BATCH of matrices (axis 0).
# The second argument (the batch of vectors) is passed entirely to EACH call
# of the inner-mapped function (hence None axis).
# Output shape for batch of 3 matrices and batch of 4 vectors: (3, 4, 2)
pairwise_matvec = jax.vmap(jax.vmap(matvec_multiply, in_axes=(None, 0)), in_axes=(0, None))
result = pairwise_matvec(batch_of_matrices, batch_of_vectors_outer)
print("Shape of pairwise matvec result:", result.shape)
# Expected output: Shape of pairwise matvec result: (3, 4, 2)
# Let's verify one element: result[i, j] should be matvec_multiply(batch_of_matrices[i], batch_of_vectors_outer[j])
manual_calc = matvec_multiply(batch_of_matrices[1], batch_of_vectors_outer[2])
print("Manual calculation for (1, 2):", manual_calc)
print("Result from nested vmap for (1, 2):", result[1, 2])
In the nested vmap call jax.vmap(jax.vmap(f, in_axes=(ax1, ax2)), in_axes=(ax3, ax4)):
- vmap(f, in_axes=(ax1, ax2)) creates a function that maps f over axes ax1 and ax2 of its inputs.
- vmap(..., in_axes=(ax3, ax4)) takes this newly created function and maps it over axes ax3 and ax4 of its inputs.
- None in in_axes means the corresponding argument is not mapped over at that level; it is broadcast, or held constant, for that mapping dimension.
- The pattern vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None)) (or vice versa) is extremely common for computing pairwise interactions between elements of two different batches, as the check below illustrates.
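The loop-based check below is a sketch added for illustration; it confirms that the nested vmap from above produces the same values as two explicit Python loops over the batch dimensions:
# Equivalent computation with explicit loops (slower, but easy to reason about).
looped_pairwise = jnp.stack([
    jnp.stack([matvec_multiply(m, v) for v in batch_of_vectors_outer])
    for m in batch_of_matrices
])
print("Nested vmap matches loops:", jnp.allclose(result, looped_pairwise))
print(looped_pairwise.shape)  # (3, 4, 2)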
Let's compute the Euclidean distance between every point in a set A and every point in a set B.
def euclidean_distance_sq(point_a, point_b):
    """Calculates the squared Euclidean distance between two points."""
    # Assumes points are 1D vectors
    return jnp.sum((point_a - point_b)**2)
# Batch A: 3 points in 2D space
points_a = jnp.array([[1., 0.],
                      [0., 1.],
                      [-1., 0.]]) # Shape (3, 2)
# Batch B: 4 points in 2D space
points_b = jnp.array([[2., 2.],
                      [-2., 2.],
                      [2., -2.],
                      [-2., -2.]]) # Shape (4, 2)
# Goal: Compute a 3x4 matrix where entry (i, j) is distance(points_a[i], points_b[j])
# Inner map: Compute distance between ONE point from A and ALL points in B
# Map over points_b (axis 0), keep point_a fixed (None)
inner_map = jax.vmap(euclidean_distance_sq, in_axes=(None, 0))
# inner_map(points_a[0], points_b) would compute distances from points_a[0] to all points_b
# Outer map: Apply inner_map to EACH point in A
# Map over points_a (axis 0), pass the whole points_b batch (None axis) to the inner map
pairwise_distance_sq = jax.vmap(inner_map, in_axes=(0, None))
# Equivalent direct definition:
pairwise_distance_sq_direct = jax.vmap(
jax.vmap(euclidean_distance_sq, in_axes=(None, 0)), # Map over B for fixed A
in_axes=(0, None) # Map over A, passing full B batch
)
distance_matrix = pairwise_distance_sq(points_a, points_b)
distance_matrix_direct = pairwise_distance_sq_direct(points_a, points_b)
print("Shape of distance matrix:", distance_matrix.shape)
# Expected output: Shape of distance matrix: (3, 4)
print("Distance Matrix (Squared):\n", distance_matrix)
# Verify one element manually: distance_sq(points_a[0], points_b[0])
manual_dist_sq_00 = euclidean_distance_sq(points_a[0], points_b[0]) # (1-2)^2 + (0-2)^2 = 1 + 4 = 5
print("Manual dist_sq[0, 0]:", manual_dist_sq_00)
print("Matrix element [0, 0]:", distance_matrix[0, 0])
# Expected output: Manual dist_sq[0, 0]: 5.0, Matrix element [0, 0]: 5.0
This nested vmap cleanly expresses the desired pairwise computation without manual broadcasting or loops. JAX compiles this nested structure efficiently.
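For comparison only, the same squared-distance matrix can be written with manual broadcasting; this is exactly the bookkeeping that nested vmap lets you avoid:
# Manual broadcasting version (shown only to check the result).
diffs = points_a[:, None, :] - points_b[None, :, :]     # Shape (3, 4, 2)
distance_matrix_broadcast = jnp.sum(diffs**2, axis=-1)  # Shape (3, 4)
print("Broadcasting matches nested vmap:", jnp.allclose(distance_matrix, distance_matrix_broadcast))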
Nested vmaps compose with other transformations like jit and grad. You can JIT-compile a function that uses nested vmaps for performance, or compute gradients through them.
# JIT-compiling the pairwise distance function
jitted_pairwise_distance_sq = jax.jit(pairwise_distance_sq_direct)
distance_matrix_jitted = jitted_pairwise_distance_sq(points_a, points_b)
print("JITted Distance Matrix (Squared):\n", distance_matrix_jitted)
# Example: Gradient of the sum of pairwise distances w.r.t points_a
def sum_of_distances(p_a, p_b):
    # Calculate all pairwise distances
    dist_matrix = pairwise_distance_sq_direct(p_a, p_b)
    # Return the sum
    return jnp.sum(dist_matrix)
# Compute gradient w.r.t the first argument (points_a)
grad_sum_dist_wrt_a = jax.grad(sum_of_distances, argnums=0)
gradients_a = grad_sum_dist_wrt_a(points_a, points_b)
print("Shape of gradients w.r.t points_a:", gradients_a.shape)
# Expected output: Shape of gradients w.r.t points_a: (3, 2) (same shape as points_a)
print("Gradients w.r.t points_a:\n", gradients_a)
Nesting vmap provides a powerful way to handle complex batching structures in a functional and composable manner. While the in_axes specification can seem intricate at first, understanding that each vmap introduces one level of mapping helps in constructing the desired nested transformation. The pattern for pairwise computations, in particular, is a valuable tool in many scientific computing and machine learning tasks.
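As a final illustration (a sketch added here), swapping which level maps over which argument simply transposes the batch axes of the result; the nesting order is a choice about the output layout you want:
# Swapped nesting: inner map over matrices, outer map over vectors.
pairwise_matvec_swapped = jax.vmap(
    jax.vmap(matvec_multiply, in_axes=(0, None)),  # inner: map over matrices
    in_axes=(None, 0),                             # outer: map over vectors
)
print(pairwise_matvec_swapped(batch_of_matrices, batch_of_vectors_outer).shape)
# (4, 3, 2) instead of (3, 4, 2)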