While vmap is powerful for adding a single batch dimension, sometimes your data or computation involves multiple levels of batching or mapping. For instance, you might have a batch of sentences, where each sentence is a sequence (batch) of words, and you want to apply a function to each word embedding. Or perhaps you want to compute the interaction between every element in one batch and every element in another batch.
JAX handles these scenarios elegantly through nesting vmap. Since vmap is itself a function transformation, you can apply it multiple times: applying vmap to a function that has already been transformed by vmap lets you map over multiple dimensions simultaneously.
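As a quick illustration of the word-embedding scenario, here is a minimal sketch; the normalize function, shapes, and data are hypothetical, chosen only to show two nested vmaps covering the sentence and word dimensions:
import jax
import jax.numpy as jnp
def normalize(embedding):
    # Operates on a SINGLE word embedding of shape (embed_dim,)
    return embedding / jnp.linalg.norm(embedding)
# Hypothetical batch: 8 sentences, 16 words each, 32-dim embeddings
word_embeddings = jnp.ones((8, 16, 32))
# Inner vmap maps over words; outer vmap maps over sentences
# (in_axes defaults to 0 at each level)
normalize_all = jax.vmap(jax.vmap(normalize))
print(normalize_all(word_embeddings).shape)
# Expected: (8, 16, 32)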
Let's reconsider a simple function, like matrix-vector multiplication. Suppose we have a function designed for a single matrix and a single vector:
import jax
import jax.numpy as jnp
def matvec_multiply(matrix, vector):
"""Computes the product of a matrix and a vector."""
# matrix: [M, N], vector: [N] -> result: [M]
return jnp.dot(matrix, vector)
# Example single inputs
matrix = jnp.arange(6.).reshape(2, 3) # Shape (2, 3)
vector = jnp.arange(3.) # Shape (3,)
print("Single matvec:", matvec_multiply(matrix, vector))
# Expected output: [ 5. 14.] (Shape (2,))
Now, imagine you have a batch of matrices and a corresponding batch of vectors, and you want to compute the product for each pair.
# Batch of matrices (3 matrices, each 2x3)
batch_of_matrices = jnp.arange(18.).reshape(3, 2, 3)
# Batch of vectors (3 vectors, each 3)
batch_of_vectors = jnp.arange(9.).reshape(3, 3)
We can achieve this by applying vmap once, mapping over the first dimension (axis 0) of both inputs:
# Map over axis 0 for both arguments
batched_matvec = jax.vmap(matvec_multiply, in_axes=(0, 0))
print("Batched matvec (paired):", batched_matvec(batch_of_matrices, batch_of_vectors))
# Expected output shape: (3, 2) -> 3 results, each of shape (2,)
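As a purely illustrative sanity check, the vmapped result should match calling the unbatched function in a Python loop over the leading axis:
# Sanity check: compare against an explicit Python loop over axis 0
looped = jnp.stack([matvec_multiply(m, v)
                    for m, v in zip(batch_of_matrices, batch_of_vectors)])
print(jnp.allclose(batched_matvec(batch_of_matrices, batch_of_vectors), looped))
# Expected: True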
This works because vmap adds one batch dimension. But what if you want to compute the product of each matrix in the batch with each vector in a potentially different batch? This is like computing an "outer product" at the batch level. It requires mapping over the batch of matrices and, independently, over the batch of vectors. This is where nesting vmap comes in.
# Batch of matrices (3 matrices, each 2x3)
batch_of_matrices = jnp.arange(18.).reshape(3, 2, 3)
# Another batch of vectors (say, 4 vectors, each 3)
batch_of_vectors_outer = jnp.arange(12.).reshape(4, 3)
# Goal: Compute matvec for all 3x4 combinations. Result shape should be (3, 4, 2)
# Inner vmap: Map over the vectors (axis 0), keeping the matrix fixed (None)
inner_mapped_matvec = jax.vmap(matvec_multiply, in_axes=(None, 0))
# inner_mapped_matvec takes one matrix and a batch of vectors
# Outer vmap: map inner_mapped_matvec over the matrices (axis 0).
# The batch of vectors does not vary with the outer iteration, so its
# in_axes entry is None: each outer call receives the entire vector batch,
# which the inner vmap then maps over.
# Defining the whole transformation directly, read from inside out:
# 1. vmap(matvec_multiply, in_axes=(None, 0)):
# Creates a function that takes ONE matrix and a BATCH of vectors,
# applies matvec_multiply to the matrix and EACH vector.
# Output shape for one matrix and batch of 4 vectors: (4, 2)
# 2. vmap(..., in_axes=(0, None)):
# Takes the function from step 1. Maps this function over a BATCH of matrices (axis 0).
# The second argument (the batch of vectors) is passed entirely to EACH call
# of the inner-mapped function (hence None axis).
# Output shape for batch of 3 matrices and batch of 4 vectors: (3, 4, 2)
pairwise_matvec = jax.vmap(jax.vmap(matvec_multiply, in_axes=(None, 0)), in_axes=(0, None))
result = pairwise_matvec(batch_of_matrices, batch_of_vectors_outer)
print("Shape of pairwise matvec result:", result.shape)
# Expected output: Shape of pairwise matvec result: (3, 4, 2)
# Let's verify one element: result[i, j] should be matvec_multiply(batch_of_matrices[i], batch_of_vectors_outer[j])
manual_calc = matvec_multiply(batch_of_matrices[1], batch_of_vectors_outer[2])
print("Manual calculation for (1, 2):", manual_calc)
print("Result from nested vmap for (1, 2):", result[1, 2])
In the nested vmap call jax.vmap(jax.vmap(f, in_axes=(ax1, ax2)), in_axes=(ax3, ax4)):
- vmap(f, in_axes=(ax1, ax2)) creates a function that maps f over axes ax1 and ax2 of its inputs.
- vmap(..., in_axes=(ax3, ax4)) takes this newly created function and maps it over axes ax3 and ax4 of its inputs.
- None in in_axes means the corresponding argument is not mapped over at that level; it is broadcast, or held constant, for that mapping dimension.
The pattern vmap(vmap(f, in_axes=(None, 0)), in_axes=(0, None)) (or vice versa, as sketched below) is extremely common for computing pairwise interactions between elements of two different batches.
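The "vice versa" ordering simply swaps which batch axis comes first in the output. A minimal sketch, reusing the arrays defined above:
# Swapped nesting: inner map over matrices, outer map over vectors
pairwise_matvec_swapped = jax.vmap(
    jax.vmap(matvec_multiply, in_axes=(0, None)),  # map over matrices for one vector
    in_axes=(None, 0)                              # map over vectors
)
swapped = pairwise_matvec_swapped(batch_of_matrices, batch_of_vectors_outer)
print(swapped.shape)
# Expected: (4, 3, 2), the (3, 4, 2) result with the batch axes transposed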
Let's compute the squared Euclidean distance between every point in a set A and every point in a set B.
def euclidean_distance_sq(point_a, point_b):
"""Calculates the squared Euclidean distance between two points."""
# Assumes points are 1D vectors
return jnp.sum((point_a - point_b)**2)
# Batch A: 3 points in 2D space
points_a = jnp.array([[1., 0.],
[0., 1.],
[-1., 0.]]) # Shape (3, 2)
# Batch B: 4 points in 2D space
points_b = jnp.array([[2., 2.],
[-2., 2.],
[2., -2.],
[-2., -2.]]) # Shape (4, 2)
# Goal: Compute a 3x4 matrix where entry (i, j) is the squared distance between points_a[i] and points_b[j]
# Inner map: Compute distance between ONE point from A and ALL points in B
# Map over points_b (axis 0), keep point_a fixed (None)
inner_map = jax.vmap(euclidean_distance_sq, in_axes=(None, 0))
# inner_map(points_a[0], points_b) would compute distances from points_a[0] to all points_b
# Outer map: Apply inner_map to EACH point in A
# Map over points_a (axis 0), pass the whole points_b batch (None axis) to the inner map
pairwise_distance_sq = jax.vmap(inner_map, in_axes=(0, None))
# Equivalent direct definition:
pairwise_distance_sq_direct = jax.vmap(
jax.vmap(euclidean_distance_sq, in_axes=(None, 0)), # Map over B for fixed A
in_axes=(0, None) # Map over A, passing full B batch
)
distance_matrix = pairwise_distance_sq(points_a, points_b)
distance_matrix_direct = pairwise_distance_sq_direct(points_a, points_b)
print("Shape of distance matrix:", distance_matrix.shape)
# Expected output: Shape of distance matrix: (3, 4)
print("Distance Matrix (Squared):\n", distance_matrix)
# Verify one element manually: distance_sq(points_a[0], points_b[0])
manual_dist_sq_00 = euclidean_distance_sq(points_a[0], points_b[0]) # (1-2)^2 + (0-2)^2 = 1 + 4 = 5
print("Manual dist_sq[0, 0]:", manual_dist_sq_00)
print("Matrix element [0, 0]:", distance_matrix[0, 0])
# Expected output: Manual dist_sq[0, 0]: 5.0, Matrix element [0, 0]: 5.0
This nested vmap cleanly expresses the desired pairwise computation without manual broadcasting or loops, and JAX compiles the nested structure efficiently.
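For comparison, the same matrix can be produced with manual broadcasting by inserting singleton axes by hand; the nested vmap spares you this bookkeeping:
# Manual broadcasting equivalent: (3, 1, 2) - (1, 4, 2) -> (3, 4, 2), then reduce
manual_matrix = jnp.sum((points_a[:, None, :] - points_b[None, :, :])**2, axis=-1)
print(jnp.allclose(distance_matrix, manual_matrix))
# Expected: True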
Nested vmaps compose seamlessly with other transformations like jit and grad. You can JIT-compile a function that uses nested vmaps for performance, or compute gradients through them.
# JIT-compiling the pairwise distance function
jitted_pairwise_distance_sq = jax.jit(pairwise_distance_sq_direct)
distance_matrix_jitted = jitted_pairwise_distance_sq(points_a, points_b)
print("JITted Distance Matrix (Squared):\n", distance_matrix_jitted)
# Example: Gradient of the sum of pairwise distances w.r.t points_a
def sum_of_distances(p_a, p_b):
# Calculate all pairwise distances
dist_matrix = pairwise_distance_sq_direct(p_a, p_b)
# Return the sum
return jnp.sum(dist_matrix)
# Compute gradient w.r.t the first argument (points_a)
grad_sum_dist_wrt_a = jax.grad(sum_of_distances, argnums=0)
gradients_a = grad_sum_dist_wrt_a(points_a, points_b)
print("Shape of gradients w.r.t points_a:", gradients_a.shape)
# Expected output: Shape of gradients w.r.t points_a: (3, 2) (same shape as points_a)
print("Gradients w.r.t points_a:\n", gradients_a)
Nesting vmap provides a powerful way to handle complex batching structures in a functional and composable manner. While the in_axes specification can seem intricate at first, remembering that each vmap introduces exactly one level of mapping helps in constructing the desired nested transformation. The pattern for pairwise computations, in particular, is a valuable tool in many scientific computing and machine learning tasks.